feat: split up handle and model_settings (#6022)
This commit is contained in:
committed by
Caren Thomas
parent
19e6b09da5
commit
0b1fe096ec
@@ -18932,6 +18932,30 @@
|
||||
"deprecated": true
|
||||
},
|
||||
"model": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Model",
|
||||
"description": "The model handle used by the agent (format: provider/model-name)."
|
||||
},
|
||||
"embedding": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Embedding",
|
||||
"description": "The embedding model handle used by the agent (format: provider/model-name)."
|
||||
},
|
||||
"model_settings": {
|
||||
"anyOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/ModelSettings"
|
||||
@@ -18940,18 +18964,7 @@
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "The model used by the agent."
|
||||
},
|
||||
"embedding": {
|
||||
"anyOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/EmbeddingModelSettings"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "The embedding model used by the agent."
|
||||
"description": "The model settings used by the agent."
|
||||
},
|
||||
"response_format": {
|
||||
"anyOf": [
|
||||
@@ -22943,15 +22956,12 @@
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ModelSettings"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Model",
|
||||
"description": "The model handle or model settings for the agent to use, specified either by a handle or an object. See the model schema for more information."
|
||||
"description": "The model handle for the agent to use (format: provider/model-name)."
|
||||
},
|
||||
"embedding": {
|
||||
"anyOf": [
|
||||
@@ -22959,14 +22969,22 @@
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/EmbeddingModelSettings"
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Embedding",
|
||||
"description": "The embedding model handle used by the agent (format: provider/model-name)."
|
||||
},
|
||||
"model_settings": {
|
||||
"anyOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/ModelSettings"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Embedding",
|
||||
"description": "The embedding configuration handle used by the agent, specified in the format provider/model-name."
|
||||
"description": "The model settings for the agent."
|
||||
},
|
||||
"context_window_limit": {
|
||||
"anyOf": [
|
||||
@@ -24248,24 +24266,6 @@
|
||||
],
|
||||
"title": "EmbeddingModel"
|
||||
},
|
||||
"EmbeddingModelSettings": {
|
||||
"properties": {
|
||||
"model": {
|
||||
"type": "string",
|
||||
"title": "Model",
|
||||
"description": "The name of the model."
|
||||
},
|
||||
"provider": {
|
||||
"type": "string",
|
||||
"enum": ["openai", "ollama"],
|
||||
"title": "Provider",
|
||||
"description": "The provider of the model."
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["model", "provider"],
|
||||
"title": "EmbeddingModelSettings"
|
||||
},
|
||||
"EventMessage": {
|
||||
"properties": {
|
||||
"id": {
|
||||
@@ -26794,15 +26794,12 @@
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ModelSettings"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Model",
|
||||
"description": "The model handle or model settings for the agent to use, specified either by a handle or an object. See the model schema for more information."
|
||||
"description": "The model handle for the agent to use (format: provider/model-name)."
|
||||
},
|
||||
"embedding": {
|
||||
"anyOf": [
|
||||
@@ -26810,14 +26807,22 @@
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/EmbeddingModelSettings"
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Embedding",
|
||||
"description": "The embedding model handle used by the agent (format: provider/model-name)."
|
||||
},
|
||||
"model_settings": {
|
||||
"anyOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/ModelSettings"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Embedding",
|
||||
"description": "The embedding configuration handle used by the agent, specified in the format provider/model-name."
|
||||
"description": "The model settings for the agent."
|
||||
},
|
||||
"context_window_limit": {
|
||||
"anyOf": [
|
||||
@@ -30017,11 +30022,6 @@
|
||||
},
|
||||
"ModelSettings": {
|
||||
"properties": {
|
||||
"model": {
|
||||
"type": "string",
|
||||
"title": "Model",
|
||||
"description": "The name of the model."
|
||||
},
|
||||
"max_output_tokens": {
|
||||
"type": "integer",
|
||||
"title": "Max Output Tokens",
|
||||
@@ -30030,7 +30030,6 @@
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["model"],
|
||||
"title": "ModelSettings",
|
||||
"description": "Schema for defining settings for a model"
|
||||
},
|
||||
@@ -35289,15 +35288,12 @@
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ModelSettings"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Model",
|
||||
"description": "The model used by the agent, specified either by a handle or an object. See the model schema for more information."
|
||||
"description": "The model handle used by the agent (format: provider/model-name)."
|
||||
},
|
||||
"embedding": {
|
||||
"anyOf": [
|
||||
@@ -35305,14 +35301,22 @@
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/EmbeddingModelSettings"
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Embedding",
|
||||
"description": "The embedding model handle used by the agent (format: provider/model-name)."
|
||||
},
|
||||
"model_settings": {
|
||||
"anyOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/ModelSettings"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Embedding",
|
||||
"description": "The embedding configuration handle used by the agent, specified in the format provider/model-name."
|
||||
"description": "The model settings for the agent."
|
||||
},
|
||||
"context_window_limit": {
|
||||
"anyOf": [
|
||||
@@ -36297,15 +36301,12 @@
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ModelSettings"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Model",
|
||||
"description": "The model handle or model settings for the agent to use, specified either by a handle or an object. See the model schema for more information."
|
||||
"description": "The model handle for the agent to use (format: provider/model-name)."
|
||||
},
|
||||
"embedding": {
|
||||
"anyOf": [
|
||||
@@ -36313,14 +36314,22 @@
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/EmbeddingModelSettings"
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Embedding",
|
||||
"description": "The embedding model handle used by the agent (format: provider/model-name)."
|
||||
},
|
||||
"model_settings": {
|
||||
"anyOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/ModelSettings"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Embedding",
|
||||
"description": "The embedding configuration handle used by the agent, specified in the format provider/model-name."
|
||||
"description": "The model settings for the agent."
|
||||
},
|
||||
"context_window_limit": {
|
||||
"anyOf": [
|
||||
|
||||
@@ -285,7 +285,9 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
|
||||
if resolver:
|
||||
state[field_name] = resolver()
|
||||
|
||||
state["model"] = self.llm_config._to_model() if self.llm_config else None
|
||||
state["model"] = self.llm_config.handle if self.llm_config else None
|
||||
state["model_settings"] = self.llm_config._to_model_settings() if self.llm_config else None
|
||||
state["embedding"] = self.embedding_config.handle if self.embedding_config else None
|
||||
|
||||
return self.__pydantic_model__(**state)
|
||||
|
||||
@@ -425,6 +427,8 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
|
||||
state["managed_group"] = multi_agent_group
|
||||
state["tool_exec_environment_variables"] = tool_exec_environment_variables
|
||||
state["secrets"] = tool_exec_environment_variables
|
||||
state["model"] = self.llm_config._to_model() if self.llm_config else None
|
||||
state["model"] = self.llm_config.handle if self.llm_config else None
|
||||
state["model_settings"] = self.llm_config._to_model_settings() if self.llm_config else None
|
||||
state["embedding"] = self.embedding_config.handle if self.embedding_config else None
|
||||
|
||||
return self.__pydantic_model__(**state)
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Dict, List, Literal, Optional
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from letta.constants import CORE_MEMORY_LINE_NUMBER_WARNING, DEFAULT_EMBEDDING_CHUNK_SIZE
|
||||
from letta.errors import AgentExportProcessingError
|
||||
from letta.errors import AgentExportProcessingError, LettaInvalidArgumentError
|
||||
from letta.schemas.block import Block, CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
@@ -18,7 +18,7 @@ from letta.schemas.letta_stop_reason import StopReasonType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.model import EmbeddingModelSettings, ModelSettings
|
||||
from letta.schemas.model import ModelSettings
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.response_format import ResponseFormatUnion
|
||||
from letta.schemas.source import Source
|
||||
@@ -83,8 +83,9 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
|
||||
embedding_config: EmbeddingConfig = Field(
|
||||
..., description="Deprecated: Use `embedding` field instead. The embedding configuration used by the agent.", deprecated=True
|
||||
)
|
||||
model: Optional[ModelSettings] = Field(None, description="The model used by the agent.")
|
||||
embedding: Optional[EmbeddingModelSettings] = Field(None, description="The embedding model used by the agent.")
|
||||
model: Optional[str] = Field(None, description="The model handle used by the agent (format: provider/model-name).")
|
||||
embedding: Optional[str] = Field(None, description="The embedding model handle used by the agent (format: provider/model-name).")
|
||||
model_settings: Optional[ModelSettings] = Field(None, description="The model settings used by the agent.")
|
||||
|
||||
response_format: Optional[ResponseFormatUnion] = Field(
|
||||
None,
|
||||
@@ -229,13 +230,12 @@ class CreateAgent(BaseModel, validate_assignment=True): #
|
||||
embedding_config: Optional[EmbeddingConfig] = Field(
|
||||
None, description="Deprecated: Use `embedding` field instead. The embedding configuration used by the agent.", deprecated=True
|
||||
)
|
||||
model: Optional[str | ModelSettings] = Field( # TODO: make this required (breaking change)
|
||||
model: Optional[str] = Field( # TODO: make this required (breaking change)
|
||||
None,
|
||||
description="The model handle or model settings for the agent to use, specified either by a handle or an object. See the model schema for more information.",
|
||||
)
|
||||
embedding: Optional[str | EmbeddingModelSettings] = Field(
|
||||
None, description="The embedding configuration handle used by the agent, specified in the format provider/model-name."
|
||||
description="The model handle for the agent to use (format: provider/model-name).",
|
||||
)
|
||||
embedding: Optional[str] = Field(None, description="The embedding model handle used by the agent (format: provider/model-name).")
|
||||
model_settings: Optional[ModelSettings] = Field(None, description="The model settings for the agent.")
|
||||
|
||||
context_window_limit: Optional[int] = Field(None, description="The context window limit used by the agent.")
|
||||
embedding_chunk_size: Optional[int] = Field(
|
||||
@@ -348,9 +348,12 @@ class CreateAgent(BaseModel, validate_assignment=True): #
|
||||
if not model:
|
||||
return model
|
||||
|
||||
if "/" not in model:
|
||||
raise LettaInvalidArgumentError("The model handle should be in the format provider/model-name", argument_name="model")
|
||||
|
||||
provider_name, model_name = model.split("/", 1)
|
||||
if not provider_name or not model_name:
|
||||
raise ValueError("The llm config handle should be in the format provider/model-name")
|
||||
raise LettaInvalidArgumentError("The model handle should be in the format provider/model-name", argument_name="model")
|
||||
|
||||
return model
|
||||
|
||||
@@ -360,9 +363,12 @@ class CreateAgent(BaseModel, validate_assignment=True): #
|
||||
if not embedding:
|
||||
return embedding
|
||||
|
||||
if "/" not in embedding:
|
||||
raise ValueError("The embedding handle should be in the format provider/model-name")
|
||||
|
||||
provider_name, embedding_name = embedding.split("/", 1)
|
||||
if not provider_name or not embedding_name:
|
||||
raise ValueError("The embedding config handle should be in the format provider/model-name")
|
||||
raise ValueError("The embedding handle should be in the format provider/model-name")
|
||||
|
||||
return embedding
|
||||
|
||||
@@ -410,13 +416,12 @@ class UpdateAgent(BaseModel):
|
||||
)
|
||||
|
||||
# model configuration
|
||||
model: Optional[str | ModelSettings] = Field(
|
||||
model: Optional[str] = Field(
|
||||
None,
|
||||
description="The model used by the agent, specified either by a handle or an object. See the model schema for more information.",
|
||||
)
|
||||
embedding: Optional[str | EmbeddingModelSettings] = Field(
|
||||
None, description="The embedding configuration handle used by the agent, specified in the format provider/model-name."
|
||||
description="The model handle used by the agent (format: provider/model-name).",
|
||||
)
|
||||
embedding: Optional[str] = Field(None, description="The embedding model handle used by the agent (format: provider/model-name).")
|
||||
model_settings: Optional[ModelSettings] = Field(None, description="The model settings for the agent.")
|
||||
context_window_limit: Optional[int] = Field(None, description="The context window limit used by the agent.")
|
||||
reasoning: Optional[bool] = Field(
|
||||
None,
|
||||
|
||||
@@ -255,7 +255,7 @@ class LLMConfig(BaseModel):
|
||||
+ (f" [ip={self.model_endpoint}]" if self.model_endpoint else "")
|
||||
)
|
||||
|
||||
def _to_model(self) -> "ModelSettings":
|
||||
def _to_model_settings(self) -> "ModelSettings":
|
||||
"""
|
||||
Convert LLMConfig back into a Model schema (OpenAIModelSettings, AnthropicModelSettings, etc.).
|
||||
This is the inverse of the _to_legacy_config_params() methods in model.py.
|
||||
@@ -279,7 +279,6 @@ class LLMConfig(BaseModel):
|
||||
|
||||
if self.model_endpoint_type == "openai":
|
||||
return OpenAIModelSettings(
|
||||
model=self.model,
|
||||
max_output_tokens=self.max_tokens or 4096,
|
||||
temperature=self.temperature,
|
||||
reasoning=OpenAIReasoning(reasoning_effort=self.reasoning_effort or "minimal"),
|
||||
@@ -287,7 +286,6 @@ class LLMConfig(BaseModel):
|
||||
elif self.model_endpoint_type == "anthropic":
|
||||
thinking_type = "enabled" if self.enable_reasoner else "disabled"
|
||||
return AnthropicModelSettings(
|
||||
model=self.model,
|
||||
max_output_tokens=self.max_tokens or 4096,
|
||||
temperature=self.temperature,
|
||||
thinking=AnthropicThinking(type=thinking_type, budget_tokens=self.max_reasoning_tokens or 1024),
|
||||
@@ -295,7 +293,6 @@ class LLMConfig(BaseModel):
|
||||
)
|
||||
elif self.model_endpoint_type == "google_ai":
|
||||
return GoogleAIModelSettings(
|
||||
model=self.model,
|
||||
max_output_tokens=self.max_tokens or 65536,
|
||||
temperature=self.temperature,
|
||||
thinking_config=GeminiThinkingConfig(
|
||||
@@ -304,7 +301,6 @@ class LLMConfig(BaseModel):
|
||||
)
|
||||
elif self.model_endpoint_type == "google_vertex":
|
||||
return GoogleVertexModelSettings(
|
||||
model=self.model,
|
||||
max_output_tokens=self.max_tokens or 65536,
|
||||
temperature=self.temperature,
|
||||
thinking_config=GeminiThinkingConfig(
|
||||
@@ -313,39 +309,34 @@ class LLMConfig(BaseModel):
|
||||
)
|
||||
elif self.model_endpoint_type == "azure":
|
||||
return AzureModelSettings(
|
||||
model=self.model,
|
||||
max_output_tokens=self.max_tokens or 4096,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
elif self.model_endpoint_type == "xai":
|
||||
return XAIModelSettings(
|
||||
model=self.model,
|
||||
max_output_tokens=self.max_tokens or 4096,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
elif self.model_endpoint_type == "groq":
|
||||
return GroqModelSettings(
|
||||
model=self.model,
|
||||
max_output_tokens=self.max_tokens or 4096,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
elif self.model_endpoint_type == "deepseek":
|
||||
return DeepseekModelSettings(
|
||||
model=self.model,
|
||||
max_output_tokens=self.max_tokens or 4096,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
elif self.model_endpoint_type == "together":
|
||||
return TogetherModelSettings(
|
||||
model=self.model,
|
||||
max_output_tokens=self.max_tokens or 4096,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
elif self.model_endpoint_type == "bedrock":
|
||||
return Model(model=self.model, max_output_tokens=self.max_tokens or 4096)
|
||||
return Model(max_output_tokens=self.max_tokens or 4096)
|
||||
else:
|
||||
# If we don't know the model type, use the default Model schema
|
||||
return Model(model=self.model, max_output_tokens=self.max_tokens or 4096)
|
||||
return Model(max_output_tokens=self.max_tokens or 4096)
|
||||
|
||||
@classmethod
|
||||
def is_openai_reasoning_model(cls, config: "LLMConfig") -> bool:
|
||||
|
||||
@@ -120,6 +120,25 @@ class Model(LLMConfig, ModelBase):
|
||||
provider_category=llm_config.provider_category,
|
||||
)
|
||||
|
||||
@property
|
||||
def model_settings_schema(self) -> Optional[dict]:
|
||||
"""Returns the JSON schema for the ModelSettings class corresponding to this model's provider."""
|
||||
PROVIDER_SETTINGS_MAP = {
|
||||
ProviderType.openai: OpenAIModelSettings,
|
||||
ProviderType.anthropic: AnthropicModelSettings,
|
||||
ProviderType.google_ai: GoogleAIModelSettings,
|
||||
ProviderType.google_vertex: GoogleVertexModelSettings,
|
||||
ProviderType.azure: AzureModelSettings,
|
||||
ProviderType.xai: XAIModelSettings,
|
||||
ProviderType.groq: GroqModelSettings,
|
||||
ProviderType.deepseek: DeepseekModelSettings,
|
||||
ProviderType.together: TogetherModelSettings,
|
||||
ProviderType.bedrock: BedrockModelSettings,
|
||||
}
|
||||
|
||||
settings_class = PROVIDER_SETTINGS_MAP.get(self.provider_type)
|
||||
return settings_class.model_json_schema() if settings_class else None
|
||||
|
||||
|
||||
class EmbeddingModel(EmbeddingConfig, ModelBase):
|
||||
model_type: Literal["embedding"] = Field("embedding", description="Type of model (llm or embedding)")
|
||||
@@ -184,8 +203,9 @@ class EmbeddingModel(EmbeddingConfig, ModelBase):
|
||||
class ModelSettings(BaseModel):
|
||||
"""Schema for defining settings for a model"""
|
||||
|
||||
model: str = Field(..., description="The name of the model.")
|
||||
# model: str = Field(..., description="The name of the model.")
|
||||
max_output_tokens: int = Field(4096, description="The maximum number of tokens the model can generate.")
|
||||
parallel_tool_calls: bool = Field(False, description="Whether to enable parallel tool calling.")
|
||||
|
||||
|
||||
class OpenAIReasoning(BaseModel):
|
||||
@@ -201,7 +221,7 @@ class OpenAIReasoning(BaseModel):
|
||||
|
||||
class OpenAIModelSettings(ModelSettings):
|
||||
provider: Literal["openai"] = Field("openai", description="The provider of the model.")
|
||||
temperature: float = Field(0.7, description="The temperature of the model.")
|
||||
temperature: float = Field(0.7, ge=0.0, le=1.0, description="The temperature of the model.")
|
||||
reasoning: OpenAIReasoning = Field(OpenAIReasoning(reasoning_effort="high"), description="The reasoning configuration for the model.")
|
||||
response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the model.")
|
||||
|
||||
@@ -220,6 +240,7 @@ class OpenAIModelSettings(ModelSettings):
|
||||
"max_tokens": self.max_output_tokens,
|
||||
"reasoning_effort": self.reasoning.reasoning_effort,
|
||||
"response_format": self.response_format,
|
||||
"parallel_tool_calls": self.parallel_tool_calls,
|
||||
}
|
||||
|
||||
|
||||
@@ -231,7 +252,7 @@ class OpenAIModelSettings(ModelSettings):
|
||||
|
||||
class AnthropicThinking(BaseModel):
|
||||
type: Literal["enabled", "disabled"] = Field("enabled", description="The type of thinking to use.")
|
||||
budget_tokens: int = Field(1024, description="The maximum number of tokens the model can use for extended thinking.")
|
||||
budget_tokens: int = Field(1024, ge=0, le=1024, description="The maximum number of tokens the model can use for extended thinking.")
|
||||
|
||||
|
||||
class AnthropicModelSettings(ModelSettings):
|
||||
@@ -258,17 +279,18 @@ class AnthropicModelSettings(ModelSettings):
|
||||
"extended_thinking": self.thinking.type == "enabled",
|
||||
"thinking_budget_tokens": self.thinking.budget_tokens,
|
||||
"verbosity": self.verbosity,
|
||||
"parallel_tool_calls": self.parallel_tool_calls,
|
||||
}
|
||||
|
||||
|
||||
class GeminiThinkingConfig(BaseModel):
|
||||
include_thoughts: bool = Field(True, description="Whether to include thoughts in the model's response.")
|
||||
thinking_budget: int = Field(1024, description="The thinking budget for the model.")
|
||||
thinking_budget: int = Field(1024, ge=0, le=1024, description="The thinking budget for the model.")
|
||||
|
||||
|
||||
class GoogleAIModelSettings(ModelSettings):
|
||||
provider: Literal["google_ai"] = Field("google_ai", description="The provider of the model.")
|
||||
temperature: float = Field(0.7, description="The temperature of the model.")
|
||||
temperature: float = Field(0.7, ge=0.0, le=1.0, description="The temperature of the model.")
|
||||
thinking_config: GeminiThinkingConfig = Field(
|
||||
GeminiThinkingConfig(include_thoughts=True, thinking_budget=1024), description="The thinking configuration for the model."
|
||||
)
|
||||
@@ -280,6 +302,7 @@ class GoogleAIModelSettings(ModelSettings):
|
||||
"temperature": self.temperature,
|
||||
"max_tokens": self.max_output_tokens,
|
||||
"max_reasoning_tokens": self.thinking_config.thinking_budget if self.thinking_config.include_thoughts else 0,
|
||||
"parallel_tool_calls": self.parallel_tool_calls,
|
||||
}
|
||||
|
||||
|
||||
@@ -291,7 +314,7 @@ class AzureModelSettings(ModelSettings):
|
||||
"""Azure OpenAI model configuration (OpenAI-compatible)."""
|
||||
|
||||
provider: Literal["azure"] = Field("azure", description="The provider of the model.")
|
||||
temperature: float = Field(0.7, description="The temperature of the model.")
|
||||
temperature: float = Field(0.7, ge=0.0, le=1.0, description="The temperature of the model.")
|
||||
response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the model.")
|
||||
|
||||
def _to_legacy_config_params(self) -> dict:
|
||||
@@ -299,6 +322,7 @@ class AzureModelSettings(ModelSettings):
|
||||
"temperature": self.temperature,
|
||||
"max_tokens": self.max_output_tokens,
|
||||
"response_format": self.response_format,
|
||||
"parallel_tool_calls": self.parallel_tool_calls,
|
||||
}
|
||||
|
||||
|
||||
@@ -306,7 +330,7 @@ class XAIModelSettings(ModelSettings):
|
||||
"""xAI model configuration (OpenAI-compatible)."""
|
||||
|
||||
provider: Literal["xai"] = Field("xai", description="The provider of the model.")
|
||||
temperature: float = Field(0.7, description="The temperature of the model.")
|
||||
temperature: float = Field(0.7, ge=0.0, le=1.0, description="The temperature of the model.")
|
||||
response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the model.")
|
||||
|
||||
def _to_legacy_config_params(self) -> dict:
|
||||
@@ -314,6 +338,7 @@ class XAIModelSettings(ModelSettings):
|
||||
"temperature": self.temperature,
|
||||
"max_tokens": self.max_output_tokens,
|
||||
"response_format": self.response_format,
|
||||
"parallel_tool_calls": self.parallel_tool_calls,
|
||||
}
|
||||
|
||||
|
||||
@@ -321,7 +346,7 @@ class GroqModelSettings(ModelSettings):
|
||||
"""Groq model configuration (OpenAI-compatible)."""
|
||||
|
||||
provider: Literal["groq"] = Field("groq", description="The provider of the model.")
|
||||
temperature: float = Field(0.7, description="The temperature of the model.")
|
||||
temperature: float = Field(0.7, ge=0.0, le=1.0, description="The temperature of the model.")
|
||||
response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the model.")
|
||||
|
||||
def _to_legacy_config_params(self) -> dict:
|
||||
@@ -336,7 +361,7 @@ class DeepseekModelSettings(ModelSettings):
|
||||
"""Deepseek model configuration (OpenAI-compatible)."""
|
||||
|
||||
provider: Literal["deepseek"] = Field("deepseek", description="The provider of the model.")
|
||||
temperature: float = Field(0.7, description="The temperature of the model.")
|
||||
temperature: float = Field(0.7, ge=0.0, le=1.0, description="The temperature of the model.")
|
||||
response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the model.")
|
||||
|
||||
def _to_legacy_config_params(self) -> dict:
|
||||
@@ -344,6 +369,7 @@ class DeepseekModelSettings(ModelSettings):
|
||||
"temperature": self.temperature,
|
||||
"max_tokens": self.max_output_tokens,
|
||||
"response_format": self.response_format,
|
||||
"parallel_tool_calls": self.parallel_tool_calls,
|
||||
}
|
||||
|
||||
|
||||
@@ -351,7 +377,7 @@ class TogetherModelSettings(ModelSettings):
|
||||
"""Together AI model configuration (OpenAI-compatible)."""
|
||||
|
||||
provider: Literal["together"] = Field("together", description="The provider of the model.")
|
||||
temperature: float = Field(0.7, description="The temperature of the model.")
|
||||
temperature: float = Field(0.7, ge=0.0, le=1.0, description="The temperature of the model.")
|
||||
response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the model.")
|
||||
|
||||
def _to_legacy_config_params(self) -> dict:
|
||||
@@ -359,6 +385,7 @@ class TogetherModelSettings(ModelSettings):
|
||||
"temperature": self.temperature,
|
||||
"max_tokens": self.max_output_tokens,
|
||||
"response_format": self.response_format,
|
||||
"parallel_tool_calls": self.parallel_tool_calls,
|
||||
}
|
||||
|
||||
|
||||
@@ -366,7 +393,7 @@ class BedrockModelSettings(ModelSettings):
|
||||
"""AWS Bedrock model configuration."""
|
||||
|
||||
provider: Literal["bedrock"] = Field("bedrock", description="The provider of the model.")
|
||||
temperature: float = Field(0.7, description="The temperature of the model.")
|
||||
temperature: float = Field(0.7, ge=0.0, le=1.0, description="The temperature of the model.")
|
||||
response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the model.")
|
||||
|
||||
def _to_legacy_config_params(self) -> dict:
|
||||
@@ -374,6 +401,7 @@ class BedrockModelSettings(ModelSettings):
|
||||
"temperature": self.temperature,
|
||||
"max_tokens": self.max_output_tokens,
|
||||
"response_format": self.response_format,
|
||||
"parallel_tool_calls": self.parallel_tool_calls,
|
||||
}
|
||||
|
||||
|
||||
@@ -392,8 +420,3 @@ ModelSettingsUnion = Annotated[
|
||||
],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
||||
|
||||
class EmbeddingModelSettings(BaseModel):
|
||||
model: str = Field(..., description="The name of the model.")
|
||||
provider: Literal["openai", "ollama"] = Field(..., description="The provider of the model.")
|
||||
|
||||
@@ -436,6 +436,8 @@ class SyncServer(object):
|
||||
handle = f"{request.model.provider}/{request.model.model}"
|
||||
# TODO: figure out how to override various params
|
||||
additional_config_params = request.model._to_legacy_config_params()
|
||||
additional_config_params["model"] = request.model.model
|
||||
additional_config_params["provider_name"] = request.model.provider
|
||||
|
||||
config_params = {
|
||||
"handle": handle,
|
||||
@@ -525,6 +527,11 @@ class SyncServer(object):
|
||||
request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params)
|
||||
log_event(name="end get_cached_llm_config", attributes=config_params)
|
||||
|
||||
# update with model_settings
|
||||
if request.model_settings is not None:
|
||||
update_llm_config_params = request.model_settings._to_legacy_config_params()
|
||||
request.llm_config.update(update_llm_config_params)
|
||||
|
||||
# Copy parallel_tool_calls from request to llm_config if provided
|
||||
if request.parallel_tool_calls is not None:
|
||||
if request.llm_config is None:
|
||||
|
||||
@@ -6,5 +6,5 @@
|
||||
"model_wrapper": null,
|
||||
"put_inner_thoughts_in_kwargs": true,
|
||||
"enable_reasoner": true,
|
||||
"max_reasoning_tokens": 20000
|
||||
"max_reasoning_tokens": 1000
|
||||
}
|
||||
|
||||
@@ -1250,7 +1250,7 @@ async def test_agent_state_schema_unchanged(server: SyncServer):
|
||||
from letta.schemas.group import Group
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.model import EmbeddingModelSettings, ModelSettings
|
||||
from letta.schemas.model import ModelSettings
|
||||
from letta.schemas.response_format import ResponseFormatUnion
|
||||
from letta.schemas.source import Source
|
||||
from letta.schemas.tool import Tool
|
||||
@@ -1271,9 +1271,10 @@ async def test_agent_state_schema_unchanged(server: SyncServer):
|
||||
"agent_type": AgentType,
|
||||
# LLM information
|
||||
"llm_config": LLMConfig,
|
||||
"model": ModelSettings,
|
||||
"embedding": EmbeddingModelSettings,
|
||||
"model": str,
|
||||
"embedding": str,
|
||||
"embedding_config": EmbeddingConfig,
|
||||
"model_settings": ModelSettings,
|
||||
"response_format": (ResponseFormatUnion, type(None)),
|
||||
# State fields
|
||||
"description": (str, type(None)),
|
||||
|
||||
Reference in New Issue
Block a user