feat: add chatgpt oauth client for codex routing (#8774)
* base * refresh * use default model fallback * patch * streaming * generate
This commit is contained in:
@@ -2,6 +2,9 @@ import json
|
||||
from typing import AsyncGenerator, List
|
||||
|
||||
from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
||||
from letta.interfaces.anthropic_parallel_tool_call_streaming_interface import SimpleAnthropicStreamingInterface
|
||||
from letta.interfaces.gemini_streaming_interface import SimpleGeminiStreamingInterface
|
||||
@@ -75,12 +78,22 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
||||
run_id=self.run_id,
|
||||
step_id=step_id,
|
||||
)
|
||||
elif self.llm_config.model_endpoint_type in [ProviderType.openai, ProviderType.deepseek, ProviderType.zai]:
|
||||
elif self.llm_config.model_endpoint_type in [
|
||||
ProviderType.openai,
|
||||
ProviderType.deepseek,
|
||||
ProviderType.zai,
|
||||
ProviderType.chatgpt_oauth,
|
||||
]:
|
||||
# Decide interface based on payload shape
|
||||
use_responses = "input" in request_data and "messages" not in request_data
|
||||
# No support for Responses API proxy
|
||||
is_proxy = self.llm_config.provider_name == "lmstudio_openai"
|
||||
|
||||
# ChatGPT OAuth always uses Responses API format
|
||||
if self.llm_config.model_endpoint_type == ProviderType.chatgpt_oauth:
|
||||
use_responses = True
|
||||
is_proxy = False
|
||||
|
||||
if use_responses and not is_proxy:
|
||||
self.interface = SimpleOpenAIResponsesStreamingInterface(
|
||||
is_openai_proxy=False,
|
||||
@@ -109,9 +122,6 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
||||
else:
|
||||
raise ValueError(f"Streaming not supported for provider {self.llm_config.model_endpoint_type}")
|
||||
|
||||
# Extract optional parameters
|
||||
# ttft_span = kwargs.get('ttft_span', None)
|
||||
|
||||
# Start the streaming request (map provider errors to common LLMError types)
|
||||
try:
|
||||
# Gemini uses async generator pattern (no await) to maintain connection lifecycle
|
||||
|
||||
1036
letta/llm_api/chatgpt_oauth_client.py
Normal file
1036
letta/llm_api/chatgpt_oauth_client.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -100,6 +100,13 @@ class LLMClient:
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor=actor,
|
||||
)
|
||||
case ProviderType.chatgpt_oauth:
|
||||
from letta.llm_api.chatgpt_oauth_client import ChatGPTOAuthClient
|
||||
|
||||
return ChatGPTOAuthClient(
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor=actor,
|
||||
)
|
||||
case _:
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
|
||||
|
||||
@@ -55,6 +55,7 @@ class ProviderType(str, Enum):
|
||||
azure = "azure"
|
||||
bedrock = "bedrock"
|
||||
cerebras = "cerebras"
|
||||
chatgpt_oauth = "chatgpt_oauth"
|
||||
deepseek = "deepseek"
|
||||
google_ai = "google_ai"
|
||||
google_vertex = "google_vertex"
|
||||
|
||||
@@ -49,6 +49,7 @@ class LLMConfig(BaseModel):
|
||||
"deepseek",
|
||||
"xai",
|
||||
"zai",
|
||||
"chatgpt_oauth",
|
||||
] = Field(..., description="The endpoint type for the model.")
|
||||
model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
|
||||
provider_name: Optional[str] = Field(None, description="The provider name for the model.")
|
||||
@@ -308,6 +309,8 @@ class LLMConfig(BaseModel):
|
||||
AnthropicThinking,
|
||||
AzureModelSettings,
|
||||
BedrockModelSettings,
|
||||
ChatGPTOAuthModelSettings,
|
||||
ChatGPTOAuthReasoning,
|
||||
DeepseekModelSettings,
|
||||
GeminiThinkingConfig,
|
||||
GoogleAIModelSettings,
|
||||
@@ -382,7 +385,16 @@ class LLMConfig(BaseModel):
|
||||
temperature=self.temperature,
|
||||
)
|
||||
elif self.model_endpoint_type == "bedrock":
|
||||
return Model(max_output_tokens=self.max_tokens or 4096)
|
||||
return BedrockModelSettings(
|
||||
max_output_tokens=self.max_tokens or 4096,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
elif self.model_endpoint_type == "chatgpt_oauth":
|
||||
return ChatGPTOAuthModelSettings(
|
||||
max_output_tokens=self.max_tokens or 4096,
|
||||
temperature=self.temperature,
|
||||
reasoning=ChatGPTOAuthReasoning(reasoning_effort=self.reasoning_effort or "medium"),
|
||||
)
|
||||
else:
|
||||
# If we don't know the model type, use the default Model schema
|
||||
return Model(max_output_tokens=self.max_tokens or 4096)
|
||||
|
||||
@@ -48,6 +48,7 @@ class Model(LLMConfig, ModelBase):
|
||||
"deepseek",
|
||||
"xai",
|
||||
"zai",
|
||||
"chatgpt_oauth",
|
||||
] = Field(..., description="Deprecated: Use 'provider_type' field instead. The endpoint type for the model.", deprecated=True)
|
||||
context_window: int = Field(
|
||||
..., description="Deprecated: Use 'max_context_window' field instead. The context window size for the model.", deprecated=True
|
||||
@@ -434,6 +435,32 @@ class BedrockModelSettings(ModelSettings):
|
||||
}
|
||||
|
||||
|
||||
class ChatGPTOAuthReasoning(BaseModel):
    """Reasoning configuration for ChatGPT OAuth models (GPT-5.x, o-series)."""

    # Allowed effort levels are fixed by the Literal; "medium" is used when the
    # caller does not specify one.
    reasoning_effort: Literal["none", "low", "medium", "high", "xhigh"] = Field(
        default="medium",
        description="The reasoning effort level for GPT-5.x and o-series models.",
    )
|
||||
|
||||
|
||||
class ChatGPTOAuthModelSettings(ModelSettings):
    """ChatGPT OAuth model configuration (uses ChatGPT backend API)."""

    # Discriminator value: ModelSettingsUnion selects the settings class by `provider_type`.
    provider_type: Literal[ProviderType.chatgpt_oauth] = Field(
        default=ProviderType.chatgpt_oauth,
        description="The type of the provider.",
    )
    temperature: float = Field(default=0.7, description="The temperature of the model.")
    reasoning: ChatGPTOAuthReasoning = Field(
        default=ChatGPTOAuthReasoning(reasoning_effort="medium"),
        description="The reasoning configuration for the model.",
    )

    def _to_legacy_config_params(self) -> dict:
        """Flatten these settings into the legacy config parameter dict.

        Returns:
            A dict with the keys ``temperature``, ``max_tokens``,
            ``reasoning_effort`` and ``parallel_tool_calls``.
        """
        legacy_params = {}
        legacy_params["temperature"] = self.temperature
        # `max_output_tokens` and `parallel_tool_calls` are not declared in this
        # class; presumably inherited from ModelSettings - verify against the base.
        legacy_params["max_tokens"] = self.max_output_tokens
        # The nested reasoning object is collapsed to its single effort value.
        legacy_params["reasoning_effort"] = self.reasoning.reasoning_effort
        legacy_params["parallel_tool_calls"] = self.parallel_tool_calls
        return legacy_params
|
||||
|
||||
|
||||
ModelSettingsUnion = Annotated[
|
||||
Union[
|
||||
OpenAIModelSettings,
|
||||
@@ -447,6 +474,7 @@ ModelSettingsUnion = Annotated[
|
||||
DeepseekModelSettings,
|
||||
TogetherModelSettings,
|
||||
BedrockModelSettings,
|
||||
ChatGPTOAuthModelSettings,
|
||||
],
|
||||
Field(discriminator="provider_type"),
|
||||
]
|
||||
|
||||
@@ -5,6 +5,7 @@ from .azure import AzureProvider
|
||||
from .base import Provider, ProviderBase, ProviderCheck, ProviderCreate, ProviderUpdate
|
||||
from .bedrock import BedrockProvider
|
||||
from .cerebras import CerebrasProvider
|
||||
from .chatgpt_oauth import ChatGPTOAuthProvider
|
||||
from .deepseek import DeepSeekProvider
|
||||
from .google_gemini import GoogleAIProvider
|
||||
from .google_vertex import GoogleVertexProvider
|
||||
@@ -31,7 +32,8 @@ __all__ = [
|
||||
"AnthropicProvider",
|
||||
"AzureProvider",
|
||||
"BedrockProvider",
|
||||
"CerebrasProvider", # NEW
|
||||
"CerebrasProvider",
|
||||
"ChatGPTOAuthProvider",
|
||||
"DeepSeekProvider",
|
||||
"GoogleAIProvider",
|
||||
"GoogleVertexProvider",
|
||||
|
||||
@@ -184,6 +184,7 @@ class Provider(ProviderBase):
|
||||
AzureProvider,
|
||||
BedrockProvider,
|
||||
CerebrasProvider,
|
||||
ChatGPTOAuthProvider,
|
||||
DeepSeekProvider,
|
||||
GoogleAIProvider,
|
||||
GoogleVertexProvider,
|
||||
@@ -229,6 +230,8 @@ class Provider(ProviderBase):
|
||||
return DeepSeekProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.cerebras:
|
||||
return CerebrasProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.chatgpt_oauth:
|
||||
return ChatGPTOAuthProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.xai:
|
||||
return XAIProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.zai:
|
||||
|
||||
366
letta/schemas/providers/chatgpt_oauth.py
Normal file
366
letta/schemas/providers/chatgpt_oauth.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""ChatGPT OAuth Provider - uses chatgpt.com/backend-api/codex with OAuth authentication."""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Literal, Optional
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import ProviderCategory, ProviderType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.providers.base import Provider
|
||||
from letta.schemas.secret import Secret
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm import User
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# ChatGPT backend API configuration
CHATGPT_CODEX_ENDPOINT = "https://chatgpt.com/backend-api/codex/responses"
CHATGPT_TOKEN_REFRESH_URL = "https://auth.openai.com/oauth/token"

# Public OAuth client_id used by Codex CLI / Letta Code. Token refresh requires
# it, and it must match the client_id used in the initial OAuth authorization flow.
CHATGPT_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"

# Refresh tokens this many seconds (5 minutes) before they actually expire.
TOKEN_REFRESH_BUFFER_SECONDS = 5 * 60

# Hardcoded (model name, context window) pairs for models reachable through the
# ChatGPT backend with a ChatGPT Plus/Pro subscription. The list is based on the
# opencode-openai-codex-auth plugin presets. Reasoning effort is not encoded
# here; it is configured via llm_config.reasoning_effort.
_CHATGPT_MODEL_CONTEXT_WINDOWS = (
    # GPT-5.2 models (supports none/low/medium/high/xhigh reasoning)
    ("gpt-5.2", 272000),
    ("gpt-5.2-codex", 272000),
    # GPT-5.1 models
    ("gpt-5.1", 272000),
    ("gpt-5.1-codex", 272000),
    ("gpt-5.1-codex-mini", 272000),
    ("gpt-5.1-codex-max", 272000),
    # GPT-5 Codex models (original)
    ("gpt-5-codex-mini", 272000),
    # GPT-4o models (for ChatGPT Plus users)
    ("gpt-4o", 128000),
    ("gpt-4o-mini", 128000),
    # o-series reasoning models
    ("o1", 200000),
    ("o1-pro", 200000),
    ("o3", 200000),
    ("o3-mini", 200000),
    ("o4-mini", 200000),
)

# Public form consumed by the provider: one dict per model, in table order.
CHATGPT_MODELS = [
    {"name": model_name, "context_window": window_size}
    for model_name, window_size in _CHATGPT_MODEL_CONTEXT_WINDOWS
]
|
||||
|
||||
|
||||
class ChatGPTOAuthCredentials(BaseModel):
    """OAuth credentials for ChatGPT backend API access.

    These credentials are stored as JSON in the provider's api_key_enc field.
    """

    access_token: str = Field(..., description="OAuth access token for ChatGPT API")
    refresh_token: str = Field(..., description="OAuth refresh token for obtaining new access tokens")
    account_id: str = Field(..., description="ChatGPT account ID for the ChatGPT-Account-Id header")
    expires_at: int = Field(..., description="Unix timestamp when the access_token expires")

    def is_expired(self, buffer_seconds: int = TOKEN_REFRESH_BUFFER_SECONDS) -> bool:
        """Check if token is expired or will expire within buffer_seconds.

        Handles both seconds and milliseconds timestamps (auto-detects based on magnitude).

        Args:
            buffer_seconds: Safety margin; the token counts as expired this many
                seconds before its real expiry.

        Returns:
            True when the current time is at or past ``expires_at - buffer_seconds``.
        """
        # Function-scope import keeps this fix self-contained.
        import time

        expires_at = self.expires_at
        # Auto-detect milliseconds (13+ digits) vs seconds (10 digits).
        # Timestamps > 10^12 are definitely milliseconds (year 33658 in seconds).
        if expires_at > 10**12:
            expires_at = expires_at // 1000  # Convert ms to seconds

        # time.time() is the real Unix epoch time. The previous
        # datetime.utcnow().timestamp() produced a naive datetime that
        # .timestamp() interprets as *local* time, skewing the result by the
        # host's UTC offset on any server not running in UTC.
        current_time = time.time()
        is_expired = current_time >= (expires_at - buffer_seconds)
        logger.debug(f"Token expiry check: current={current_time}, expires_at={expires_at}, buffer={buffer_seconds}, expired={is_expired}")
        return is_expired

    def to_json(self) -> str:
        """Serialize to JSON string for storage in api_key_enc."""
        return self.model_dump_json()

    @classmethod
    def from_json(cls, json_str: str) -> "ChatGPTOAuthCredentials":
        """Deserialize from JSON string stored in api_key_enc.

        Args:
            json_str: JSON text expected to hold an object with the credential fields.

        Returns:
            The parsed credentials.

        Raises:
            json.JSONDecodeError: If ``json_str`` is not valid JSON.
            ValueError: If the JSON is valid but is not an object, or (via model
                validation) if fields are missing or have the wrong type.
        """
        data = json.loads(json_str)
        # Valid JSON that is not an object (null, list, string, number) would
        # make `cls(**data)` raise TypeError; surface it as ValueError so it is
        # reported the same way as every other malformed-credential case.
        if not isinstance(data, dict):
            raise ValueError("ChatGPT OAuth credentials must be a JSON object")
        return cls(**data)
|
||||
|
||||
|
||||
class ChatGPTOAuthProvider(Provider):
    """
    ChatGPT OAuth Provider for accessing ChatGPT's backend-api with OAuth tokens.

    This provider enables using ChatGPT Plus/Pro subscription credentials to access
    OpenAI models through the ChatGPT backend API at chatgpt.com/backend-api/codex.

    OAuth credentials are stored as JSON in the api_key_enc field:
    {
        "access_token": "...",
        "refresh_token": "...",
        "account_id": "...",
        "expires_at": 1234567890
    }

    The client (e.g., Letta Code) performs the OAuth flow and sends the credentials
    to the backend via the provider creation API.
    """

    provider_type: Literal[ProviderType.chatgpt_oauth] = Field(
        ProviderType.chatgpt_oauth,
        description="The type of the provider.",
    )
    provider_category: ProviderCategory = Field(
        ProviderCategory.byok,  # Always BYOK since it uses user's OAuth credentials
        description="The category of the provider (always byok for OAuth)",
    )
    base_url: str = Field(
        CHATGPT_CODEX_ENDPOINT,
        description="Base URL for the ChatGPT backend API.",
    )

    async def get_oauth_credentials(self) -> Optional[ChatGPTOAuthCredentials]:
        """Retrieve and parse OAuth credentials from api_key_enc.

        Returns:
            ChatGPTOAuthCredentials if valid credentials exist, None otherwise.
        """
        if not self.api_key_enc:
            return None

        json_str = await self.api_key_enc.get_plaintext_async()
        if not json_str:
            return None

        try:
            return ChatGPTOAuthCredentials.from_json(json_str)
        except (ValueError, TypeError) as e:
            # ValueError covers json.JSONDecodeError and model validation errors.
            # TypeError covers valid JSON that is not an object (e.g. `null` or
            # a list), which previously escaped despite the "None otherwise" contract.
            logger.error(f"Failed to parse ChatGPT OAuth credentials: {e}")
            return None

    async def refresh_token_if_needed(
        self, actor: Optional["User"] = None, force_refresh: bool = False
    ) -> Optional[ChatGPTOAuthCredentials]:
        """Check if token needs refresh and refresh if necessary.

        This method is called before each API request to ensure valid credentials.
        Tokens are refreshed 5 minutes before expiry to avoid edge cases.

        Args:
            actor: The user performing the action. Required for persisting refreshed credentials.
            force_refresh: If True, always refresh the token regardless of expiry. For testing only.

        Returns:
            Updated credentials if successful, None on failure.
        """
        creds = await self.get_oauth_credentials()
        if not creds:
            return None

        if not creds.is_expired() and not force_refresh:
            return creds

        # Token needs refresh
        logger.debug(f"ChatGPT OAuth token refresh triggered (expired={creds.is_expired()}, force={force_refresh})")

        try:
            new_creds = await self._perform_token_refresh(creds)
            # Update stored credentials in memory and persist to database
            await self._update_stored_credentials(new_creds, actor=actor)
            return new_creds
        except Exception as e:
            # Deliberately broad: any refresh failure degrades to "use the old
            # token if it still works" instead of failing the request outright.
            logger.error(f"Failed to refresh ChatGPT OAuth token: {e}")
            # If refresh fails but original access_token is still valid, use it
            if not creds.is_expired():
                logger.warning("Token refresh failed, but original access_token is still valid - using existing token")
                return creds
            # Both refresh failed AND token is expired - return None to trigger auth error
            return None

    async def _perform_token_refresh(self, creds: ChatGPTOAuthCredentials) -> ChatGPTOAuthCredentials:
        """Perform OAuth token refresh with OpenAI's token endpoint.

        Args:
            creds: Current credentials containing the refresh_token.

        Returns:
            New ChatGPTOAuthCredentials with refreshed access_token.

        Raises:
            LLMAuthenticationError: If refresh fails due to invalid credentials.
            LLMError: If refresh fails due to network or server error.
        """
        # Function-scope import keeps this fix self-contained.
        import time

        async with httpx.AsyncClient() as client:
            try:
                response = await client.post(
                    CHATGPT_TOKEN_REFRESH_URL,
                    data={
                        "grant_type": "refresh_token",
                        "refresh_token": creds.refresh_token,
                        "client_id": CHATGPT_OAUTH_CLIENT_ID,
                    },
                    headers={
                        "Content-Type": "application/x-www-form-urlencoded",
                    },
                    timeout=30.0,
                )
                response.raise_for_status()
                data = response.json()

                # Calculate new expiry time from the real Unix epoch clock.
                # datetime.utcnow().timestamp() was wrong on non-UTC hosts: the
                # naive UTC value is interpreted as local time by .timestamp().
                expires_in = data.get("expires_in", 3600)
                new_expires_at = int(time.time()) + expires_in

                new_access_token = data["access_token"]
                # Fall back to the current refresh token when the response
                # does not include a new one.
                new_refresh_token = data.get("refresh_token", creds.refresh_token)

                logger.debug(f"ChatGPT OAuth token refreshed, expires_in={expires_in}s")

                return ChatGPTOAuthCredentials(
                    access_token=new_access_token,
                    refresh_token=new_refresh_token,
                    account_id=creds.account_id,  # Account ID doesn't change
                    expires_at=new_expires_at,
                )
            except httpx.HTTPStatusError as e:
                # Log full error details for debugging
                try:
                    error_body = e.response.json()
                    logger.error(f"Token refresh HTTP error: {e.response.status_code} - JSON: {error_body}")
                except Exception:
                    logger.error(f"Token refresh HTTP error: {e.response.status_code} - Text: {e.response.text}")

                if e.response.status_code == 401:
                    raise LLMAuthenticationError(
                        message="Failed to refresh ChatGPT OAuth token: refresh token is invalid or expired",
                        code=ErrorCode.UNAUTHENTICATED,
                    ) from e
                raise LLMError(
                    message=f"Failed to refresh ChatGPT OAuth token: {e}",
                    code=ErrorCode.INTERNAL_SERVER_ERROR,
                ) from e
            except Exception as e:
                # Network failures, malformed JSON, a missing "access_token"
                # key, etc. all surface as a generic LLMError with the cause chained.
                logger.error(f"Token refresh error: {type(e).__name__}: {e}")
                raise LLMError(
                    message=f"Failed to refresh ChatGPT OAuth token: {e}",
                    code=ErrorCode.INTERNAL_SERVER_ERROR,
                ) from e

    async def _update_stored_credentials(self, creds: ChatGPTOAuthCredentials, actor: Optional["User"] = None) -> None:
        """Update stored credentials in memory and persist to database.

        Args:
            creds: New credentials to store.
            actor: The user performing the action. Required for database persistence.
        """
        new_secret = await Secret.from_plaintext_async(creds.to_json())
        # Update in-memory value; object.__setattr__ bypasses the model's own
        # attribute-assignment handling.
        object.__setattr__(self, "api_key_enc", new_secret)

        # Persist to database if we have an actor and provider ID
        if actor and self.id:
            try:
                # Imported here rather than at module level - presumably to
                # avoid a circular import with the services layer; confirm.
                from letta.schemas.providers.base import ProviderUpdate
                from letta.services.provider_manager import ProviderManager

                provider_manager = ProviderManager()
                await provider_manager.update_provider_async(
                    provider_id=self.id,
                    provider_update=ProviderUpdate(api_key=creds.to_json()),
                    actor=actor,
                )
            except Exception as e:
                logger.error(f"Failed to persist refreshed credentials to database: {e}")
                # Don't fail the request - we have valid credentials in memory

    async def check_api_key(self):
        """Validate the OAuth credentials by checking token validity.

        Raises:
            ValueError: If no credentials are configured.
            LLMAuthenticationError: If credentials are invalid.
        """
        creds = await self.get_oauth_credentials()
        if not creds:
            raise ValueError("No ChatGPT OAuth credentials configured")

        # Try to refresh if needed. No actor is passed, so a refreshed token is
        # only updated in memory here, not persisted.
        creds = await self.refresh_token_if_needed()
        if not creds:
            raise LLMAuthenticationError(
                message="Failed to obtain valid ChatGPT OAuth credentials",
                code=ErrorCode.UNAUTHENTICATED,
            )

        # Optionally make a test request to validate
        # For now, we just verify we have valid-looking credentials
        if not creds.access_token or not creds.account_id:
            raise LLMAuthenticationError(
                message="ChatGPT OAuth credentials are incomplete",
                code=ErrorCode.UNAUTHENTICATED,
            )

    def get_default_max_output_tokens(self, model_name: str) -> int:
        """Get the default max output tokens for ChatGPT models.

        Args:
            model_name: Model identifier, e.g. ``"o3-mini"`` or ``"gpt-5.1-codex"``.

        Returns:
            100000 for o1/o3/o4-prefixed models, 16384 for names containing
            ``gpt-5`` or ``gpt-4``, and 4096 otherwise.
        """
        # Reasoning models (o-series) have higher limits
        if model_name.startswith(("o1", "o3", "o4")):
            return 100000
        # GPT-5.x and GPT-4 models share the same default
        if "gpt-5" in model_name or "gpt-4" in model_name:
            return 16384
        return 4096

    async def list_llm_models_async(self) -> list[LLMConfig]:
        """List available models from ChatGPT backend.

        Returns a hardcoded list of models available via ChatGPT Plus/Pro subscriptions.
        An empty list is returned when no parseable credentials are stored.
        """
        creds = await self.get_oauth_credentials()
        if not creds:
            logger.warning("Cannot list models: no valid ChatGPT OAuth credentials")
            return []

        return [
            LLMConfig(
                model=model["name"],
                model_endpoint_type="chatgpt_oauth",
                model_endpoint=self.base_url,
                context_window=model["context_window"],
                handle=self.get_handle(model["name"]),
                max_tokens=self.get_default_max_output_tokens(model["name"]),
                provider_name=self.name,
                provider_category=self.provider_category,
            )
            for model in CHATGPT_MODELS
        ]

    async def list_embedding_models_async(self) -> list:
        """ChatGPT backend does not support embedding models."""
        return []

    def get_model_context_window(self, model_name: str) -> int | None:
        """Get the context window for a model.

        Args:
            model_name: Exact model name to look up in CHATGPT_MODELS.

        Returns:
            The configured context window, or 128000 for unknown model names.
        """
        for model in CHATGPT_MODELS:
            if model["name"] == model_name:
                return model["context_window"]
        return 128000  # Default

    async def get_model_context_window_async(self, model_name: str) -> int | None:
        """Get the context window for a model (async version)."""
        return self.get_model_context_window(model_name)
|
||||
@@ -1542,6 +1542,7 @@ async def send_message(
|
||||
"zai",
|
||||
"groq",
|
||||
"deepseek",
|
||||
"chatgpt_oauth",
|
||||
]
|
||||
|
||||
# Create a new run for execution tracking
|
||||
@@ -2126,6 +2127,7 @@ async def preview_model_request(
|
||||
"zai",
|
||||
"groq",
|
||||
"deepseek",
|
||||
"chatgpt_oauth",
|
||||
]
|
||||
|
||||
if agent_eligible and model_compatible:
|
||||
@@ -2180,6 +2182,7 @@ async def summarize_messages(
|
||||
"zai",
|
||||
"groq",
|
||||
"deepseek",
|
||||
"chatgpt_oauth",
|
||||
]
|
||||
|
||||
if agent_eligible and model_compatible:
|
||||
|
||||
@@ -1693,7 +1693,7 @@ class SyncServer(object):
|
||||
# TODO: cleanup this logic
|
||||
llm_config = letta_agent.agent_state.llm_config
|
||||
# supports_token_streaming = ["openai", "anthropic", "xai", "deepseek"]
|
||||
supports_token_streaming = ["openai", "anthropic", "deepseek"] # TODO re-enable xAI once streaming is patched
|
||||
supports_token_streaming = ["openai", "anthropic", "deepseek", "chatgpt_oauth"] # TODO re-enable xAI once streaming is patched
|
||||
if stream_tokens and (llm_config.model_endpoint_type not in supports_token_streaming):
|
||||
logger.warning(
|
||||
f"Token streaming is only supported for models with type {' or '.join(supports_token_streaming)} in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False."
|
||||
@@ -1825,7 +1825,7 @@ class SyncServer(object):
|
||||
letta_multi_agent = load_multi_agent(group=group, agent_state=agent_state, actor=actor)
|
||||
|
||||
llm_config = letta_multi_agent.agent_state.llm_config
|
||||
supports_token_streaming = ["openai", "anthropic", "deepseek"]
|
||||
supports_token_streaming = ["openai", "anthropic", "deepseek", "chatgpt_oauth"]
|
||||
if stream_tokens and (llm_config.model_endpoint_type not in supports_token_streaming):
|
||||
logger.warning(
|
||||
f"Token streaming is only supported for models with type {' or '.join(supports_token_streaming)} in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False."
|
||||
|
||||
@@ -482,6 +482,7 @@ class ProviderManager:
|
||||
|
||||
try:
|
||||
# Get the provider class and create an instance
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.providers.anthropic import AnthropicProvider
|
||||
from letta.schemas.providers.azure import AzureProvider
|
||||
from letta.schemas.providers.bedrock import BedrockProvider
|
||||
@@ -491,42 +492,47 @@ class ProviderManager:
|
||||
from letta.schemas.providers.openai import OpenAIProvider
|
||||
from letta.schemas.providers.zai import ZAIProvider
|
||||
|
||||
provider_type_to_class = {
|
||||
"openai": OpenAIProvider,
|
||||
"anthropic": AnthropicProvider,
|
||||
"groq": GroqProvider,
|
||||
"google": GoogleAIProvider,
|
||||
"ollama": OllamaProvider,
|
||||
"bedrock": BedrockProvider,
|
||||
"azure": AzureProvider,
|
||||
"zai": ZAIProvider,
|
||||
}
|
||||
# ChatGPT OAuth requires cast_to_subtype to preserve api_key_enc and id
|
||||
# (needed for OAuth token refresh and database persistence)
|
||||
if provider.provider_type == ProviderType.chatgpt_oauth:
|
||||
provider_instance = provider.cast_to_subtype()
|
||||
else:
|
||||
provider_type_to_class = {
|
||||
"openai": OpenAIProvider,
|
||||
"anthropic": AnthropicProvider,
|
||||
"groq": GroqProvider,
|
||||
"google": GoogleAIProvider,
|
||||
"ollama": OllamaProvider,
|
||||
"bedrock": BedrockProvider,
|
||||
"azure": AzureProvider,
|
||||
"zai": ZAIProvider,
|
||||
}
|
||||
|
||||
provider_type = provider.provider_type.value if hasattr(provider.provider_type, "value") else str(provider.provider_type)
|
||||
provider_class = provider_type_to_class.get(provider_type)
|
||||
provider_type = provider.provider_type.value if hasattr(provider.provider_type, "value") else str(provider.provider_type)
|
||||
provider_class = provider_type_to_class.get(provider_type)
|
||||
|
||||
if not provider_class:
|
||||
logger.warning(f"No provider class found for type '{provider_type}'")
|
||||
return
|
||||
if not provider_class:
|
||||
logger.warning(f"No provider class found for type '{provider_type}'")
|
||||
return
|
||||
|
||||
# Create provider instance with necessary parameters
|
||||
api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None
|
||||
access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None
|
||||
kwargs = {
|
||||
"name": provider.name,
|
||||
"api_key": api_key,
|
||||
"provider_category": provider.provider_category,
|
||||
}
|
||||
if provider.base_url:
|
||||
kwargs["base_url"] = provider.base_url
|
||||
if access_key:
|
||||
kwargs["access_key"] = access_key
|
||||
if provider.region:
|
||||
kwargs["region"] = provider.region
|
||||
if provider.api_version:
|
||||
kwargs["api_version"] = provider.api_version
|
||||
# Create provider instance with necessary parameters
|
||||
api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None
|
||||
access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None
|
||||
kwargs = {
|
||||
"name": provider.name,
|
||||
"api_key": api_key,
|
||||
"provider_category": provider.provider_category,
|
||||
}
|
||||
if provider.base_url:
|
||||
kwargs["base_url"] = provider.base_url
|
||||
if access_key:
|
||||
kwargs["access_key"] = access_key
|
||||
if provider.region:
|
||||
kwargs["region"] = provider.region
|
||||
if provider.api_version:
|
||||
kwargs["api_version"] = provider.api_version
|
||||
|
||||
provider_instance = provider_class(**kwargs)
|
||||
provider_instance = provider_class(**kwargs)
|
||||
|
||||
# Query the provider's API for available models
|
||||
llm_models = await provider_instance.list_llm_models_async()
|
||||
|
||||
@@ -493,11 +493,12 @@ class StreamingService:
|
||||
"zai",
|
||||
"groq",
|
||||
"deepseek",
|
||||
"chatgpt_oauth",
|
||||
]
|
||||
|
||||
def _is_token_streaming_compatible(self, agent: AgentState) -> bool:
|
||||
"""Check if agent's model supports token-level streaming."""
|
||||
base_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock", "deepseek", "zai"]
|
||||
base_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock", "deepseek", "zai", "chatgpt_oauth"]
|
||||
google_letta_v1 = agent.agent_type == AgentType.letta_v1_agent and agent.llm_config.model_endpoint_type in [
|
||||
"google_ai",
|
||||
"google_vertex",
|
||||
|
||||
Reference in New Issue
Block a user