diff --git a/letta/adapters/simple_llm_stream_adapter.py b/letta/adapters/simple_llm_stream_adapter.py index 0d3f7974..ff3eca31 100644 --- a/letta/adapters/simple_llm_stream_adapter.py +++ b/letta/adapters/simple_llm_stream_adapter.py @@ -2,6 +2,9 @@ import json from typing import AsyncGenerator, List from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter +from letta.log import get_logger + +logger = get_logger(__name__) from letta.helpers.datetime_helpers import get_utc_timestamp_ns from letta.interfaces.anthropic_parallel_tool_call_streaming_interface import SimpleAnthropicStreamingInterface from letta.interfaces.gemini_streaming_interface import SimpleGeminiStreamingInterface @@ -75,12 +78,22 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter): run_id=self.run_id, step_id=step_id, ) - elif self.llm_config.model_endpoint_type in [ProviderType.openai, ProviderType.deepseek, ProviderType.zai]: + elif self.llm_config.model_endpoint_type in [ + ProviderType.openai, + ProviderType.deepseek, + ProviderType.zai, + ProviderType.chatgpt_oauth, + ]: # Decide interface based on payload shape use_responses = "input" in request_data and "messages" not in request_data # No support for Responses API proxy is_proxy = self.llm_config.provider_name == "lmstudio_openai" + # ChatGPT OAuth always uses Responses API format + if self.llm_config.model_endpoint_type == ProviderType.chatgpt_oauth: + use_responses = True + is_proxy = False + if use_responses and not is_proxy: self.interface = SimpleOpenAIResponsesStreamingInterface( is_openai_proxy=False, @@ -109,9 +122,6 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter): else: raise ValueError(f"Streaming not supported for provider {self.llm_config.model_endpoint_type}") - # Extract optional parameters - # ttft_span = kwargs.get('ttft_span', None) - # Start the streaming request (map provider errors to common LLMError types) try: # Gemini uses async generator pattern (no await) to maintain connection lifecycle diff --git a/letta/llm_api/chatgpt_oauth_client.py b/letta/llm_api/chatgpt_oauth_client.py new file mode 100644 index 00000000..6819ac95 --- /dev/null +++ b/letta/llm_api/chatgpt_oauth_client.py @@ -0,0 +1,1036 @@ +"""ChatGPT OAuth Client - handles requests to chatgpt.com/backend-api/codex/responses.""" + +import json +from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union + +import httpx +from openai import AsyncStream +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk +from openai.types.responses import ( + Response, + ResponseCompletedEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseCreatedEvent, + ResponseFunctionCallArgumentsDeltaEvent, + ResponseFunctionCallArgumentsDoneEvent, + ResponseFunctionToolCall, + ResponseInProgressEvent, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseOutputMessage, + ResponseOutputText, + ResponseReasoningItem, + ResponseReasoningSummaryPartAddedEvent, + ResponseReasoningSummaryPartDoneEvent, + ResponseReasoningSummaryTextDeltaEvent, + ResponseReasoningSummaryTextDoneEvent, + ResponseTextDeltaEvent, + ResponseTextDoneEvent, +) +from openai.types.responses.response_stream_event import ResponseStreamEvent + +from letta.errors import ( + ContextWindowExceededError, + ErrorCode, + LLMAuthenticationError, + LLMBadRequestError, + LLMConnectionError, + LLMRateLimitError, + LLMServerError, + LLMTimeoutError, +) +from letta.llm_api.llm_client_base import LLMClientBase +from letta.log import get_logger 
+from letta.otel.tracing import trace_method +from letta.schemas.enums import AgentType, ProviderCategory +from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import Message as PydanticMessage +from letta.schemas.openai.chat_completion_response import ( + ChatCompletionResponse, + Choice, + FunctionCall, + Message as ChoiceMessage, + ToolCall, + UsageStatistics, +) +from letta.schemas.providers.chatgpt_oauth import ChatGPTOAuthCredentials, ChatGPTOAuthProvider + +logger = get_logger(__name__) + +# ChatGPT Backend API endpoint +CHATGPT_CODEX_ENDPOINT = "https://chatgpt.com/backend-api/codex/responses" + + +class AsyncStreamWrapper: + """Wraps an async generator to provide async context manager protocol. + + The OpenAI SDK's AsyncStream implements __aenter__ and __aexit__, + but our custom SSE handler returns a raw async generator. This wrapper + provides the context manager protocol so it can be used with 'async with'. + """ + + def __init__(self, generator: AsyncIterator[ResponseStreamEvent]): + self._generator = generator + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + # Close the generator if it has an aclose method + if hasattr(self._generator, "aclose"): + await self._generator.aclose() + return False + + def __aiter__(self): + return self + + async def __anext__(self) -> ResponseStreamEvent: + return await self._generator.__anext__() + + +class ChatGPTOAuthClient(LLMClientBase): + """ + LLM client for ChatGPT OAuth provider. + + This client: + 1. Transforms standard OpenAI chat format to ChatGPT backend Responses API format + 2. Adds required headers (Authorization, ChatGPT-Account-Id, OpenAI-Beta, OpenAI-Originator) + 3. Makes requests to chatgpt.com/backend-api/codex/responses + 4. Transforms responses back to OpenAI ChatCompletion format + """ + + @trace_method + async def _get_provider_and_credentials_async(self, llm_config: LLMConfig) -> tuple[ChatGPTOAuthProvider, ChatGPTOAuthCredentials]: + """Get the ChatGPT OAuth provider and credentials with automatic refresh if needed. + + Args: + llm_config: The LLM configuration containing provider info. + + Returns: + Tuple of (provider, credentials). + + Raises: + LLMAuthenticationError: If credentials cannot be obtained. + """ + from letta.services.provider_manager import ProviderManager + + if llm_config.provider_category != ProviderCategory.byok: + raise ValueError("ChatGPT OAuth requires BYOK provider credentials") + + # Get provider + provider_manager = ProviderManager() + providers = await provider_manager.list_providers_async( + name=llm_config.provider_name, + actor=self.actor, + provider_category=[ProviderCategory.byok], + ) + + if not providers: + raise LLMAuthenticationError( + message=f"ChatGPT OAuth provider '{llm_config.provider_name}' not found", + code=ErrorCode.UNAUTHENTICATED, + ) + + provider: ChatGPTOAuthProvider = providers[0].cast_to_subtype() + + # Get credentials with automatic refresh (pass actor for persistence) + creds = await provider.refresh_token_if_needed(actor=self.actor) + if not creds: + raise LLMAuthenticationError( + message="Failed to obtain valid ChatGPT OAuth credentials", + code=ErrorCode.UNAUTHENTICATED, + ) + + return provider, creds + + def _build_headers(self, creds: ChatGPTOAuthCredentials) -> Dict[str, str]: + """Build required headers for ChatGPT backend API. + + Args: + creds: OAuth credentials containing access_token and account_id. + + Returns: + Dictionary of HTTP headers. 
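+
+        Example (illustrative; the credential values are hypothetical):
+
+            creds = ChatGPTOAuthCredentials(
+                access_token="<access-token>",
+                refresh_token="<refresh-token>",
+                account_id="<account-id>",
+                expires_at=1999999999,
+            )
+            headers = self._build_headers(creds)
+            # headers["Authorization"] == "Bearer <access-token>"
+            # headers["ChatGPT-Account-Id"] == "<account-id>"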
+ """ + return { + "Authorization": f"Bearer {creds.access_token}", + "ChatGPT-Account-Id": creds.account_id, + "OpenAI-Beta": "responses=v1", + "OpenAI-Originator": "codex", + "Content-Type": "application/json", + "accept": "text/event-stream", + } + + @trace_method + def build_request_data( + self, + agent_type: AgentType, + messages: List[PydanticMessage], + llm_config: LLMConfig, + tools: Optional[List[dict]] = None, + force_tool_call: Optional[str] = None, + requires_subsequent_tool_call: bool = False, + tool_return_truncation_chars: Optional[int] = None, + ) -> dict: + """ + Build request data for ChatGPT backend API in Responses API format. + + The ChatGPT backend uses the OpenAI Responses API format: + - `input` array instead of `messages` + - `role: "developer"` instead of `role: "system"` + - Structured content arrays + """ + # Use the existing method to convert messages to Responses API format + input_messages = PydanticMessage.to_openai_responses_dicts_from_list( + messages, + tool_return_truncation_chars=tool_return_truncation_chars, + ) + + # Extract system message as instructions + instructions = None + filtered_input = [] + for msg in input_messages: + if msg.get("role") == "developer": + # First developer message becomes instructions + if instructions is None: + content = msg.get("content", []) + if isinstance(content, list) and content: + instructions = content[0].get("text", "") + elif isinstance(content, str): + instructions = content + else: + filtered_input.append(msg) + else: + filtered_input.append(msg) + + # Build tool_choice + tool_choice = None + if tools: + if force_tool_call is not None: + tool_choice = {"type": "function", "name": force_tool_call} + elif requires_subsequent_tool_call: + tool_choice = "required" + else: + tool_choice = "auto" + + # Build request payload for ChatGPT backend + data: Dict[str, Any] = { + "model": llm_config.model, + "input": filtered_input, + "store": False, # Required for stateless operation + "stream": True, # ChatGPT backend requires streaming + } + + if instructions: + data["instructions"] = instructions + + if tools: + # Convert tools to Responses API format + responses_tools = [ + { + "type": "function", + "name": t.get("name"), + "description": t.get("description"), + "parameters": t.get("parameters"), + } + for t in tools + ] + data["tools"] = responses_tools + data["tool_choice"] = tool_choice + + # Note: ChatGPT backend does NOT support max_output_tokens parameter + + # Add reasoning effort for reasoning models (GPT-5.x, o-series) + if self.is_reasoning_model(llm_config) and llm_config.reasoning_effort: + data["reasoning"] = { + "effort": llm_config.reasoning_effort, + "summary": "detailed", + } + + return data + + def _transform_response_from_chatgpt_backend(self, response_data: dict) -> dict: + """Transform ChatGPT backend response to standard OpenAI ChatCompletion format. + + The ChatGPT backend returns responses in Responses API format. + This method normalizes them to ChatCompletion format. + + Args: + response_data: Raw response from ChatGPT backend. + + Returns: + Response in OpenAI ChatCompletion format. 
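+
+        Example (a sketch of the shape change; field values are hypothetical):
+
+            response_data = {
+                "id": "resp_123",
+                "model": "gpt-5.1",
+                "output": [
+                    {"type": "message", "content": [{"type": "output_text", "text": "hi"}]},
+                ],
+                "usage": {"input_tokens": 10, "output_tokens": 2},
+            }
+            result = self._transform_response_from_chatgpt_backend(response_data)
+            # result["choices"][0]["message"]["content"] == "hi"
+            # result["usage"]["total_tokens"] == 12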
+ """ + # If response is already in ChatCompletion format, return as-is + if "choices" in response_data: + return response_data + + # Extract from Responses API format + output = response_data.get("output", []) + message_content = "" + tool_calls = None + reasoning_content = "" + + for item in output: + item_type = item.get("type") + + if item_type == "message": + content_parts = item.get("content", []) + for part in content_parts: + if part.get("type") in ("output_text", "text"): + message_content += part.get("text", "") + elif part.get("type") == "refusal": + message_content += part.get("refusal", "") + + elif item_type == "function_call": + if tool_calls is None: + tool_calls = [] + tool_calls.append( + { + "id": item.get("call_id", item.get("id", "")), + "type": "function", + "function": { + "name": item.get("name", ""), + "arguments": item.get("arguments", ""), + }, + } + ) + + elif item_type == "reasoning": + # Capture reasoning/thinking content if present + summary = item.get("summary", []) + for s in summary: + if s.get("type") == "summary_text": + reasoning_content += s.get("text", "") + + # Build ChatCompletion response + finish_reason = "stop" + if tool_calls: + finish_reason = "tool_calls" + + transformed = { + "id": response_data.get("id", "chatgpt-response"), + "object": "chat.completion", + "created": response_data.get("created_at", 0), + "model": response_data.get("model", ""), + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": message_content or None, + "tool_calls": tool_calls, + }, + "finish_reason": finish_reason, + } + ], + "usage": self._transform_usage(response_data.get("usage", {})), + } + + return transformed + + def _transform_usage(self, usage: dict) -> dict: + """Transform usage statistics from Responses API format.""" + return { + "prompt_tokens": usage.get("input_tokens", 0), + "completion_tokens": usage.get("output_tokens", 0), + "total_tokens": usage.get("input_tokens", 0) + usage.get("output_tokens", 0), + } + + @trace_method + def request(self, request_data: dict, llm_config: LLMConfig) -> dict: + """Synchronous request - not recommended for ChatGPT OAuth.""" + import asyncio + + return asyncio.run(self.request_async(request_data, llm_config)) + + @trace_method + async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict: + """Make asynchronous request to ChatGPT backend API. + + Args: + request_data: Request payload in Responses API format. + llm_config: LLM configuration. + + Returns: + Response data in OpenAI ChatCompletion format. 
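+
+        Example (illustrative; the payload mirrors what build_request_data emits):
+
+            request_data = {
+                "model": "gpt-5.1",
+                "input": [{"role": "user", "content": [{"type": "input_text", "text": "Hello"}]}],
+                "store": False,
+                "stream": True,
+            }
+            response = await self.request_async(request_data, llm_config)
+            # response["choices"][0]["message"] is a standard ChatCompletion message dict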
+ """ + logger.info("ChatGPT OAuth request_async called (non-streaming path)") + _, creds = await self._get_provider_and_credentials_async(llm_config) + headers = self._build_headers(creds) + + endpoint = llm_config.model_endpoint or CHATGPT_CODEX_ENDPOINT + + # ChatGPT backend requires streaming, so we use client.stream() to handle SSE + async with httpx.AsyncClient() as client: + try: + async with client.stream( + "POST", + endpoint, + json=request_data, + headers=headers, + timeout=120.0, + ) as response: + response.raise_for_status() + # Accumulate SSE events into a final response + return await self._accumulate_sse_response(response) + + except httpx.HTTPStatusError as e: + raise self._handle_http_error(e) + except httpx.TimeoutException: + raise LLMTimeoutError( + message="ChatGPT backend request timed out", + code=ErrorCode.TIMEOUT, + ) + except httpx.RequestError as e: + raise LLMConnectionError( + message=f"Failed to connect to ChatGPT backend: {str(e)}", + code=ErrorCode.INTERNAL_SERVER_ERROR, + ) + + async def _accumulate_sse_response(self, response: httpx.Response) -> dict: + """Accumulate SSE stream into a final response. + + ChatGPT backend may return SSE even for non-streaming requests. + This method accumulates all events into a single response. + + Args: + response: httpx Response object with SSE content. + + Returns: + Accumulated response data. + """ + accumulated_content = "" + accumulated_tool_calls: List[Dict[str, Any]] = [] + model = "" + response_id = "" + usage = {} + + async for line in response.aiter_lines(): + if not line.startswith("data: "): + continue + + data_str = line[6:] # Remove "data: " prefix + if data_str == "[DONE]": + break + + try: + event = json.loads(data_str) + except json.JSONDecodeError: + continue + + # Extract response metadata + if not response_id and event.get("id"): + response_id = event["id"] + if not model and event.get("model"): + model = event["model"] + if event.get("usage"): + usage = event["usage"] + + # Handle different event types + event_type = event.get("type") + + if event_type == "response.output_item.done": + item = event.get("item", {}) + item_type = item.get("type") + + if item_type == "message": + for content in item.get("content", []): + if content.get("type") in ("output_text", "text"): + accumulated_content += content.get("text", "") + + elif item_type == "function_call": + accumulated_tool_calls.append( + { + "id": item.get("call_id", item.get("id", "")), + "type": "function", + "function": { + "name": item.get("name", ""), + "arguments": item.get("arguments", ""), + }, + } + ) + + elif event_type == "response.content_part.delta": + delta = event.get("delta", {}) + if delta.get("type") == "text_delta": + accumulated_content += delta.get("text", "") + + elif event_type == "response.done": + # Final response event + if event.get("response", {}).get("usage"): + usage = event["response"]["usage"] + + # Build final response + finish_reason = "stop" if not accumulated_tool_calls else "tool_calls" + + return { + "id": response_id or "chatgpt-response", + "object": "chat.completion", + "created": 0, + "model": model, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": accumulated_content or None, + "tool_calls": accumulated_tool_calls if accumulated_tool_calls else None, + }, + "finish_reason": finish_reason, + } + ], + "usage": self._transform_usage(usage), + } + + @trace_method + async def request_embeddings( + self, + texts: List[str], + embedding_config, + ) -> List[List[float]]: + 
"""ChatGPT backend does not support embeddings.""" + raise NotImplementedError("ChatGPT OAuth does not support embeddings") + + @trace_method + async def convert_response_to_chat_completion( + self, + response_data: dict, + input_messages: List[PydanticMessage], + llm_config: LLMConfig, + ) -> ChatCompletionResponse: + """Convert response to ChatCompletionResponse. + + Args: + response_data: Response data (already in ChatCompletion format). + input_messages: Original input messages. + llm_config: LLM configuration. + + Returns: + ChatCompletionResponse object. + """ + # Response should already be in ChatCompletion format after transformation + return ChatCompletionResponse(**response_data) + + @trace_method + async def stream_async( + self, + request_data: dict, + llm_config: LLMConfig, + ) -> AsyncStream[ResponseStreamEvent]: + """Stream response from ChatGPT backend. + + Note: ChatGPT backend uses SSE by default. This returns a custom + async generator that yields ResponseStreamEvent objects compatible + with the OpenAI SDK. + + Args: + request_data: Request payload. + llm_config: LLM configuration. + + Returns: + Async generator yielding ResponseStreamEvent objects. + """ + _, creds = await self._get_provider_and_credentials_async(llm_config) + headers = self._build_headers(creds) + + endpoint = llm_config.model_endpoint or CHATGPT_CODEX_ENDPOINT + + async def stream_generator(): + event_count = 0 + # Track output item index for proper event construction + output_index = 0 + # Track sequence_number in case backend doesn't provide it + # (OpenAI SDK expects incrementing sequence numbers starting at 0) + sequence_counter = 0 + + async with httpx.AsyncClient() as client: + async with client.stream( + "POST", + endpoint, + json=request_data, + headers=headers, + timeout=120.0, + ) as response: + # Check for error status + if response.status_code != 200: + error_body = await response.aread() + logger.error(f"ChatGPT SSE error: {response.status_code} - {error_body}") + raise self._handle_http_error_from_status(response.status_code, error_body.decode()) + + async for line in response.aiter_lines(): + if not line.startswith("data: "): + continue + + data_str = line[6:] + if data_str == "[DONE]": + break + + try: + raw_event = json.loads(data_str) + event_type = raw_event.get("type") + event_count += 1 + + # Use backend-provided sequence_number if available, else use counter + # This ensures proper ordering even if backend doesn't provide it + if "sequence_number" not in raw_event: + raw_event["sequence_number"] = sequence_counter + sequence_counter = raw_event["sequence_number"] + 1 + + # Track output index for output_item.added events + if event_type == "response.output_item.added": + output_index = raw_event.get("output_index", output_index) + + # Convert to OpenAI SDK ResponseStreamEvent + sdk_event = self._convert_to_sdk_event(raw_event, output_index) + if sdk_event: + yield sdk_event + + except json.JSONDecodeError: + logger.warning(f"Failed to parse SSE event: {data_str[:100]}") + continue + + # Wrap the async generator in AsyncStreamWrapper to provide context manager protocol + return AsyncStreamWrapper(stream_generator()) + + def _convert_to_sdk_event( + self, + raw_event: dict, + output_index: int = 0, + ) -> Optional[ResponseStreamEvent]: + """Convert raw ChatGPT backend SSE event to OpenAI SDK ResponseStreamEvent. + + Uses model_construct() to bypass validation since ChatGPT backend doesn't + provide all fields required by OpenAI SDK (e.g., sequence_number). 
+ + Args: + raw_event: Raw SSE event data from ChatGPT backend. + output_index: Current output item index. + + Returns: + OpenAI SDK ResponseStreamEvent or None if event type not handled. + """ + event_type = raw_event.get("type") + response_id = raw_event.get("response_id", "") + seq_num = raw_event.get("sequence_number", 0) + + # response.created -> ResponseCreatedEvent + if event_type == "response.created": + response_data = raw_event.get("response", {}) + return ResponseCreatedEvent.model_construct( + type="response.created", + sequence_number=seq_num, + response=Response.model_construct( + id=response_data.get("id", response_id), + created_at=response_data.get("created_at", 0), + model=response_data.get("model", ""), + object="response", + output=[], + status=response_data.get("status", "in_progress"), + parallel_tool_calls=response_data.get("parallel_tool_calls", True), + ), + ) + + # response.in_progress -> ResponseInProgressEvent + elif event_type == "response.in_progress": + response_data = raw_event.get("response", {}) + return ResponseInProgressEvent.model_construct( + type="response.in_progress", + sequence_number=seq_num, + response=Response.model_construct( + id=response_data.get("id", response_id), + created_at=response_data.get("created_at", 0), + model=response_data.get("model", ""), + object="response", + output=[], + status="in_progress", + parallel_tool_calls=response_data.get("parallel_tool_calls", True), + ), + ) + + # response.output_item.added -> ResponseOutputItemAddedEvent + elif event_type == "response.output_item.added": + item_data = raw_event.get("item", {}) + item_type = item_data.get("type") + idx = raw_event.get("output_index", output_index) + + if item_type == "message": + item = ResponseOutputMessage.model_construct( + id=item_data.get("id", ""), + type="message", + role=item_data.get("role", "assistant"), + content=[], + status=item_data.get("status", "in_progress"), + ) + elif item_type == "function_call": + item = ResponseFunctionToolCall.model_construct( + id=item_data.get("id", ""), + type="function_call", + call_id=item_data.get("call_id", ""), + name=item_data.get("name", ""), + arguments=item_data.get("arguments", ""), + status=item_data.get("status", "in_progress"), + ) + elif item_type == "reasoning": + # Reasoning item for o-series, GPT-5 models + item = ResponseReasoningItem.model_construct( + id=item_data.get("id", ""), + type="reasoning", + summary=item_data.get("summary", []), + status=item_data.get("status", "in_progress"), + ) + else: + # Unknown item type, skip + return None + + return ResponseOutputItemAddedEvent.model_construct( + type="response.output_item.added", + sequence_number=seq_num, + output_index=idx, + item=item, + ) + + # response.content_part.added -> ResponseContentPartAddedEvent + elif event_type == "response.content_part.added": + part_data = raw_event.get("part", {}) + return ResponseContentPartAddedEvent.model_construct( + type="response.content_part.added", + sequence_number=seq_num, + item_id=raw_event.get("item_id", ""), + output_index=raw_event.get("output_index", output_index), + content_index=raw_event.get("content_index", 0), + part=ResponseOutputText.model_construct( + type="output_text", + text=part_data.get("text", ""), + annotations=[], + ), + ) + + # response.output_text.delta -> ResponseTextDeltaEvent + # Note: OpenAI SDK uses "response.output_text.delta" (matching ChatGPT backend) + elif event_type == "response.output_text.delta": + return ResponseTextDeltaEvent.model_construct( + 
type="response.output_text.delta", + sequence_number=seq_num, + item_id=raw_event.get("item_id", ""), + output_index=raw_event.get("output_index", output_index), + content_index=raw_event.get("content_index", 0), + delta=raw_event.get("delta", ""), + ) + + # response.output_text.done -> ResponseTextDoneEvent + elif event_type == "response.output_text.done": + return ResponseTextDoneEvent.model_construct( + type="response.output_text.done", + sequence_number=seq_num, + item_id=raw_event.get("item_id", ""), + output_index=raw_event.get("output_index", output_index), + content_index=raw_event.get("content_index", 0), + text=raw_event.get("text", ""), + ) + + # response.function_call_arguments.delta -> ResponseFunctionCallArgumentsDeltaEvent + elif event_type == "response.function_call_arguments.delta": + return ResponseFunctionCallArgumentsDeltaEvent.model_construct( + type="response.function_call_arguments.delta", + sequence_number=seq_num, + item_id=raw_event.get("item_id", ""), + output_index=raw_event.get("output_index", output_index), + call_id=raw_event.get("call_id", ""), + delta=raw_event.get("delta", ""), + ) + + # response.function_call_arguments.done -> ResponseFunctionCallArgumentsDoneEvent + elif event_type == "response.function_call_arguments.done": + return ResponseFunctionCallArgumentsDoneEvent.model_construct( + type="response.function_call_arguments.done", + sequence_number=seq_num, + item_id=raw_event.get("item_id", ""), + output_index=raw_event.get("output_index", output_index), + call_id=raw_event.get("call_id", ""), + arguments=raw_event.get("arguments", ""), + ) + + # response.content_part.done -> ResponseContentPartDoneEvent + elif event_type == "response.content_part.done": + part_data = raw_event.get("part", {}) + return ResponseContentPartDoneEvent.model_construct( + type="response.content_part.done", + sequence_number=seq_num, + item_id=raw_event.get("item_id", ""), + output_index=raw_event.get("output_index", output_index), + content_index=raw_event.get("content_index", 0), + part=ResponseOutputText.model_construct( + type="output_text", + text=part_data.get("text", ""), + annotations=[], + ), + ) + + # response.output_item.done -> ResponseOutputItemDoneEvent + elif event_type == "response.output_item.done": + item_data = raw_event.get("item", {}) + item_type = item_data.get("type") + idx = raw_event.get("output_index", output_index) + + if item_type == "message": + # Build content from item data + content_list = [] + for c in item_data.get("content", []): + if c.get("type") in ("output_text", "text"): + content_list.append( + ResponseOutputText.model_construct( + type="output_text", + text=c.get("text", ""), + annotations=[], + ) + ) + item = ResponseOutputMessage.model_construct( + id=item_data.get("id", ""), + type="message", + role=item_data.get("role", "assistant"), + content=content_list, + status=item_data.get("status", "completed"), + ) + elif item_type == "function_call": + item = ResponseFunctionToolCall.model_construct( + id=item_data.get("id", ""), + type="function_call", + call_id=item_data.get("call_id", ""), + name=item_data.get("name", ""), + arguments=item_data.get("arguments", ""), + status=item_data.get("status", "completed"), + ) + elif item_type == "reasoning": + # Build summary from item data + summary_list = item_data.get("summary", []) + item = ResponseReasoningItem.model_construct( + id=item_data.get("id", ""), + type="reasoning", + summary=summary_list, + status=item_data.get("status", "completed"), + ) + else: + return None + + return 
ResponseOutputItemDoneEvent.model_construct( + type="response.output_item.done", + sequence_number=seq_num, + output_index=idx, + item=item, + ) + + # response.completed or response.done -> ResponseCompletedEvent + elif event_type in ("response.completed", "response.done"): + response_data = raw_event.get("response", {}) + + # Build output items from response data + output_items = [] + for out in response_data.get("output", []): + out_type = out.get("type") + if out_type == "message": + content_list = [] + for c in out.get("content", []): + if c.get("type") in ("output_text", "text"): + content_list.append( + ResponseOutputText.model_construct( + type="output_text", + text=c.get("text", ""), + annotations=[], + ) + ) + output_items.append( + ResponseOutputMessage.model_construct( + id=out.get("id", ""), + type="message", + role=out.get("role", "assistant"), + content=content_list, + status=out.get("status", "completed"), + ) + ) + elif out_type == "function_call": + output_items.append( + ResponseFunctionToolCall.model_construct( + id=out.get("id", ""), + type="function_call", + call_id=out.get("call_id", ""), + name=out.get("name", ""), + arguments=out.get("arguments", ""), + status=out.get("status", "completed"), + ) + ) + + return ResponseCompletedEvent.model_construct( + type="response.completed", + sequence_number=seq_num, + response=Response.model_construct( + id=response_data.get("id", response_id), + created_at=response_data.get("created_at", 0), + model=response_data.get("model", ""), + object="response", + output=output_items, + status=response_data.get("status", "completed"), + parallel_tool_calls=response_data.get("parallel_tool_calls", True), + usage=response_data.get("usage"), + ), + ) + + # Reasoning events (for o-series, GPT-5 models) + # response.reasoning_summary_part.added -> ResponseReasoningSummaryPartAddedEvent + elif event_type == "response.reasoning_summary_part.added": + part_data = raw_event.get("part", {}) + # Use a simple dict for Part since we use model_construct + part = {"text": part_data.get("text", ""), "type": part_data.get("type", "summary_text")} + return ResponseReasoningSummaryPartAddedEvent.model_construct( + type="response.reasoning_summary_part.added", + sequence_number=seq_num, + item_id=raw_event.get("item_id", ""), + output_index=raw_event.get("output_index", output_index), + summary_index=raw_event.get("summary_index", 0), + part=part, + ) + + # response.reasoning_summary_text.delta -> ResponseReasoningSummaryTextDeltaEvent + elif event_type == "response.reasoning_summary_text.delta": + return ResponseReasoningSummaryTextDeltaEvent.model_construct( + type="response.reasoning_summary_text.delta", + sequence_number=seq_num, + item_id=raw_event.get("item_id", ""), + output_index=raw_event.get("output_index", output_index), + summary_index=raw_event.get("summary_index", 0), + delta=raw_event.get("delta", ""), + ) + + # response.reasoning_summary_text.done -> ResponseReasoningSummaryTextDoneEvent + elif event_type == "response.reasoning_summary_text.done": + return ResponseReasoningSummaryTextDoneEvent.model_construct( + type="response.reasoning_summary_text.done", + sequence_number=seq_num, + item_id=raw_event.get("item_id", ""), + output_index=raw_event.get("output_index", output_index), + summary_index=raw_event.get("summary_index", 0), + text=raw_event.get("text", ""), + ) + + # response.reasoning_summary_part.done -> ResponseReasoningSummaryPartDoneEvent + elif event_type == "response.reasoning_summary_part.done": + part_data = 
raw_event.get("part", {})
+            part = {"text": part_data.get("text", ""), "type": part_data.get("type", "summary_text")}
+            return ResponseReasoningSummaryPartDoneEvent.model_construct(
+                type="response.reasoning_summary_part.done",
+                sequence_number=seq_num,
+                item_id=raw_event.get("item_id", ""),
+                output_index=raw_event.get("output_index", output_index),
+                summary_index=raw_event.get("summary_index", 0),
+                part=part,
+            )
+
+        # Unhandled event types
+        return None
+
+    def _handle_http_error_from_status(self, status_code: int, error_body: str) -> Exception:
+        """Create appropriate exception from HTTP status code.
+
+        Args:
+            status_code: HTTP status code.
+            error_body: Error response body.
+
+        Returns:
+            Appropriate LLM exception.
+        """
+        if status_code == 401:
+            return LLMAuthenticationError(
+                message=f"ChatGPT authentication failed: {error_body}",
+                code=ErrorCode.UNAUTHENTICATED,
+            )
+        elif status_code == 429:
+            return LLMRateLimitError(
+                message=f"ChatGPT rate limit exceeded: {error_body}",
+                code=ErrorCode.RATE_LIMIT_EXCEEDED,
+            )
+        elif status_code >= 500:
+            return LLMServerError(
+                message=f"ChatGPT server error: {error_body}",
+                code=ErrorCode.INTERNAL_SERVER_ERROR,
+            )
+        else:
+            return LLMBadRequestError(
+                message=f"ChatGPT request failed ({status_code}): {error_body}",
+                code=ErrorCode.INTERNAL_SERVER_ERROR,
+            )
+
+    def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
+        """Check if model is a reasoning model.
+
+        Args:
+            llm_config: LLM configuration.
+
+        Returns:
+            True if model supports extended reasoning.
+        """
+        model = llm_config.model.lower()
+        return "o1" in model or "o3" in model or "o4" in model or "gpt-5" in model
+
+    @trace_method
+    def handle_llm_error(self, e: Exception) -> Exception:
+        """Map ChatGPT-specific errors to common LLMError types.
+
+        Args:
+            e: Original exception.
+
+        Returns:
+            Mapped LLMError subclass.
+        """
+        if isinstance(e, httpx.HTTPStatusError):
+            return self._handle_http_error(e)
+
+        return super().handle_llm_error(e)
+
+    def _handle_http_error(self, e: httpx.HTTPStatusError) -> Exception:
+        """Handle HTTP status errors from ChatGPT backend.
+
+        Args:
+            e: HTTP status error.
+
+        Returns:
+            Appropriate LLMError subclass.
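+
+        Example (illustrative; `e` is a hypothetical httpx.HTTPStatusError):
+
+            err = self._handle_http_error(e)
+            # status 429 -> LLMRateLimitError
+            # status 400 mentioning "context"/"token" -> ContextWindowExceededError
+            # status >= 500 -> LLMServerError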
+ """ + status_code = e.response.status_code + error_text = str(e) + + try: + error_json = e.response.json() + error_message = error_json.get("error", {}).get("message", error_text) + except Exception: + error_message = error_text + + if status_code == 401: + return LLMAuthenticationError( + message=f"ChatGPT authentication failed: {error_message}", + code=ErrorCode.UNAUTHENTICATED, + ) + elif status_code == 429: + return LLMRateLimitError( + message=f"ChatGPT rate limit exceeded: {error_message}", + code=ErrorCode.RATE_LIMIT_EXCEEDED, + ) + elif status_code == 400: + if "context" in error_message.lower() or "token" in error_message.lower(): + return ContextWindowExceededError( + message=f"ChatGPT context window exceeded: {error_message}", + ) + return LLMBadRequestError( + message=f"ChatGPT bad request: {error_message}", + code=ErrorCode.INVALID_ARGUMENT, + ) + elif status_code >= 500: + return LLMServerError( + message=f"ChatGPT server error: {error_message}", + code=ErrorCode.INTERNAL_SERVER_ERROR, + ) + else: + return LLMBadRequestError( + message=f"ChatGPT request failed ({status_code}): {error_message}", + code=ErrorCode.INTERNAL_SERVER_ERROR, + ) diff --git a/letta/llm_api/llm_client.py b/letta/llm_api/llm_client.py index 264d7e2f..805e4038 100644 --- a/letta/llm_api/llm_client.py +++ b/letta/llm_api/llm_client.py @@ -100,6 +100,13 @@ class LLMClient: put_inner_thoughts_first=put_inner_thoughts_first, actor=actor, ) + case ProviderType.chatgpt_oauth: + from letta.llm_api.chatgpt_oauth_client import ChatGPTOAuthClient + + return ChatGPTOAuthClient( + put_inner_thoughts_first=put_inner_thoughts_first, + actor=actor, + ) case _: from letta.llm_api.openai_client import OpenAIClient diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index ecd66e18..41cc4d9c 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -55,6 +55,7 @@ class ProviderType(str, Enum): azure = "azure" bedrock = "bedrock" cerebras = "cerebras" + chatgpt_oauth = "chatgpt_oauth" deepseek = "deepseek" google_ai = "google_ai" google_vertex = "google_vertex" diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py index 3f1c484d..5425fe12 100644 --- a/letta/schemas/llm_config.py +++ b/letta/schemas/llm_config.py @@ -49,6 +49,7 @@ class LLMConfig(BaseModel): "deepseek", "xai", "zai", + "chatgpt_oauth", ] = Field(..., description="The endpoint type for the model.") model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.") provider_name: Optional[str] = Field(None, description="The provider name for the model.") @@ -308,6 +309,8 @@ class LLMConfig(BaseModel): AnthropicThinking, AzureModelSettings, BedrockModelSettings, + ChatGPTOAuthModelSettings, + ChatGPTOAuthReasoning, DeepseekModelSettings, GeminiThinkingConfig, GoogleAIModelSettings, @@ -382,7 +385,16 @@ class LLMConfig(BaseModel): temperature=self.temperature, ) elif self.model_endpoint_type == "bedrock": - return Model(max_output_tokens=self.max_tokens or 4096) + return BedrockModelSettings( + max_output_tokens=self.max_tokens or 4096, + temperature=self.temperature, + ) + elif self.model_endpoint_type == "chatgpt_oauth": + return ChatGPTOAuthModelSettings( + max_output_tokens=self.max_tokens or 4096, + temperature=self.temperature, + reasoning=ChatGPTOAuthReasoning(reasoning_effort=self.reasoning_effort or "medium"), + ) else: # If we don't know the model type, use the default Model schema return Model(max_output_tokens=self.max_tokens or 4096) diff --git a/letta/schemas/model.py 
b/letta/schemas/model.py index daf3291e..7c059452 100644 --- a/letta/schemas/model.py +++ b/letta/schemas/model.py @@ -48,6 +48,7 @@ class Model(LLMConfig, ModelBase): "deepseek", "xai", "zai", + "chatgpt_oauth", ] = Field(..., description="Deprecated: Use 'provider_type' field instead. The endpoint type for the model.", deprecated=True) context_window: int = Field( ..., description="Deprecated: Use 'max_context_window' field instead. The context window size for the model.", deprecated=True @@ -434,6 +435,32 @@ class BedrockModelSettings(ModelSettings): } +class ChatGPTOAuthReasoning(BaseModel): + """Reasoning configuration for ChatGPT OAuth models (GPT-5.x, o-series).""" + + reasoning_effort: Literal["none", "low", "medium", "high", "xhigh"] = Field( + "medium", description="The reasoning effort level for GPT-5.x and o-series models." + ) + + +class ChatGPTOAuthModelSettings(ModelSettings): + """ChatGPT OAuth model configuration (uses ChatGPT backend API).""" + + provider_type: Literal[ProviderType.chatgpt_oauth] = Field(ProviderType.chatgpt_oauth, description="The type of the provider.") + temperature: float = Field(0.7, description="The temperature of the model.") + reasoning: ChatGPTOAuthReasoning = Field( + ChatGPTOAuthReasoning(reasoning_effort="medium"), description="The reasoning configuration for the model." + ) + + def _to_legacy_config_params(self) -> dict: + return { + "temperature": self.temperature, + "max_tokens": self.max_output_tokens, + "reasoning_effort": self.reasoning.reasoning_effort, + "parallel_tool_calls": self.parallel_tool_calls, + } + + ModelSettingsUnion = Annotated[ Union[ OpenAIModelSettings, @@ -447,6 +474,7 @@ ModelSettingsUnion = Annotated[ DeepseekModelSettings, TogetherModelSettings, BedrockModelSettings, + ChatGPTOAuthModelSettings, ], Field(discriminator="provider_type"), ] diff --git a/letta/schemas/providers/__init__.py b/letta/schemas/providers/__init__.py index 8486ee51..6e3f5187 100644 --- a/letta/schemas/providers/__init__.py +++ b/letta/schemas/providers/__init__.py @@ -5,6 +5,7 @@ from .azure import AzureProvider from .base import Provider, ProviderBase, ProviderCheck, ProviderCreate, ProviderUpdate from .bedrock import BedrockProvider from .cerebras import CerebrasProvider +from .chatgpt_oauth import ChatGPTOAuthProvider from .deepseek import DeepSeekProvider from .google_gemini import GoogleAIProvider from .google_vertex import GoogleVertexProvider @@ -31,7 +32,8 @@ __all__ = [ "AnthropicProvider", "AzureProvider", "BedrockProvider", - "CerebrasProvider", # NEW + "CerebrasProvider", + "ChatGPTOAuthProvider", "DeepSeekProvider", "GoogleAIProvider", "GoogleVertexProvider", diff --git a/letta/schemas/providers/base.py b/letta/schemas/providers/base.py index 1f11e8cd..44e63f46 100644 --- a/letta/schemas/providers/base.py +++ b/letta/schemas/providers/base.py @@ -184,6 +184,7 @@ class Provider(ProviderBase): AzureProvider, BedrockProvider, CerebrasProvider, + ChatGPTOAuthProvider, DeepSeekProvider, GoogleAIProvider, GoogleVertexProvider, @@ -229,6 +230,8 @@ class Provider(ProviderBase): return DeepSeekProvider(**self.model_dump(exclude_none=True)) case ProviderType.cerebras: return CerebrasProvider(**self.model_dump(exclude_none=True)) + case ProviderType.chatgpt_oauth: + return ChatGPTOAuthProvider(**self.model_dump(exclude_none=True)) case ProviderType.xai: return XAIProvider(**self.model_dump(exclude_none=True)) case ProviderType.zai: diff --git a/letta/schemas/providers/chatgpt_oauth.py b/letta/schemas/providers/chatgpt_oauth.py new file mode 
100644 index 00000000..8df9cd60 --- /dev/null +++ b/letta/schemas/providers/chatgpt_oauth.py @@ -0,0 +1,366 @@ +"""ChatGPT OAuth Provider - uses chatgpt.com/backend-api/codex with OAuth authentication.""" + +import json +from datetime import datetime +from typing import TYPE_CHECKING, Literal, Optional + +import httpx +from pydantic import BaseModel, Field + +from letta.errors import ErrorCode, LLMAuthenticationError, LLMError +from letta.log import get_logger +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.providers.base import Provider +from letta.schemas.secret import Secret + +if TYPE_CHECKING: + from letta.orm import User + +logger = get_logger(__name__) + +# ChatGPT Backend API Configuration +CHATGPT_CODEX_ENDPOINT = "https://chatgpt.com/backend-api/codex/responses" +CHATGPT_TOKEN_REFRESH_URL = "https://auth.openai.com/oauth/token" + +# OAuth client_id for Codex CLI (required for token refresh) +# Must match the client_id used in the initial OAuth authorization flow +# This is the public client_id used by Codex CLI / Letta Code +CHATGPT_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" + +# Token refresh buffer (refresh 5 minutes before expiry) +TOKEN_REFRESH_BUFFER_SECONDS = 300 + +# Hardcoded models available via ChatGPT backend +# These are models that can be accessed through ChatGPT Plus/Pro subscriptions +# Model list based on opencode-openai-codex-auth plugin presets +# Reasoning effort levels are configured via llm_config.reasoning_effort +CHATGPT_MODELS = [ + # GPT-5.2 models (supports none/low/medium/high/xhigh reasoning) + {"name": "gpt-5.2", "context_window": 272000}, + {"name": "gpt-5.2-codex", "context_window": 272000}, + # GPT-5.1 models + {"name": "gpt-5.1", "context_window": 272000}, + {"name": "gpt-5.1-codex", "context_window": 272000}, + {"name": "gpt-5.1-codex-mini", "context_window": 272000}, + {"name": "gpt-5.1-codex-max", "context_window": 272000}, + # GPT-5 Codex models (original) + {"name": "gpt-5-codex-mini", "context_window": 272000}, + # GPT-4 models (for ChatGPT Plus users) + {"name": "gpt-4o", "context_window": 128000}, + {"name": "gpt-4o-mini", "context_window": 128000}, + {"name": "o1", "context_window": 200000}, + {"name": "o1-pro", "context_window": 200000}, + {"name": "o3", "context_window": 200000}, + {"name": "o3-mini", "context_window": 200000}, + {"name": "o4-mini", "context_window": 200000}, +] + + +class ChatGPTOAuthCredentials(BaseModel): + """OAuth credentials for ChatGPT backend API access. + + These credentials are stored as JSON in the provider's api_key_enc field. + """ + + access_token: str = Field(..., description="OAuth access token for ChatGPT API") + refresh_token: str = Field(..., description="OAuth refresh token for obtaining new access tokens") + account_id: str = Field(..., description="ChatGPT account ID for the ChatGPT-Account-Id header") + expires_at: int = Field(..., description="Unix timestamp when the access_token expires") + + def is_expired(self, buffer_seconds: int = TOKEN_REFRESH_BUFFER_SECONDS) -> bool: + """Check if token is expired or will expire within buffer_seconds. + + Handles both seconds and milliseconds timestamps (auto-detects based on magnitude). 
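+
+        Example (illustrative; token values are placeholders):
+
+            creds = ChatGPTOAuthCredentials(
+                access_token="tok", refresh_token="ref",
+                account_id="acct", expires_at=1700000000000,  # milliseconds
+            )
+            creds.is_expired()  # value > 10**12, so it is read as 1700000000 seconds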
+        """
+        expires_at = self.expires_at
+        # Auto-detect milliseconds (13+ digits) vs seconds (10 digits)
+        # Timestamps > 10^12 are definitely milliseconds (year 33658 in seconds)
+        if expires_at > 10**12:
+            expires_at = expires_at // 1000  # Convert ms to seconds
+
+        current_time = datetime.now().timestamp()  # current Unix epoch in seconds
+        is_expired = current_time >= (expires_at - buffer_seconds)
+        logger.debug(f"Token expiry check: current={current_time}, expires_at={expires_at}, buffer={buffer_seconds}, expired={is_expired}")
+        return is_expired
+
+    def to_json(self) -> str:
+        """Serialize to JSON string for storage in api_key_enc."""
+        return self.model_dump_json()
+
+    @classmethod
+    def from_json(cls, json_str: str) -> "ChatGPTOAuthCredentials":
+        """Deserialize from JSON string stored in api_key_enc."""
+        data = json.loads(json_str)
+        return cls(**data)
+
+
+class ChatGPTOAuthProvider(Provider):
+    """
+    ChatGPT OAuth Provider for accessing ChatGPT's backend-api with OAuth tokens.
+
+    This provider enables using ChatGPT Plus/Pro subscription credentials to access
+    OpenAI models through the ChatGPT backend API at chatgpt.com/backend-api/codex.
+
+    OAuth credentials are stored as JSON in the api_key_enc field:
+        {
+            "access_token": "...",
+            "refresh_token": "...",
+            "account_id": "...",
+            "expires_at": 1234567890
+        }
+
+    The client (e.g., Letta Code) performs the OAuth flow and sends the credentials
+    to the backend via the provider creation API.
+    """
+
+    provider_type: Literal[ProviderType.chatgpt_oauth] = Field(
+        ProviderType.chatgpt_oauth,
+        description="The type of the provider.",
+    )
+    provider_category: ProviderCategory = Field(
+        ProviderCategory.byok,  # Always BYOK since it uses user's OAuth credentials
+        description="The category of the provider (always byok for OAuth)",
+    )
+    base_url: str = Field(
+        CHATGPT_CODEX_ENDPOINT,
+        description="Base URL for the ChatGPT backend API.",
+    )
+
+    async def get_oauth_credentials(self) -> Optional[ChatGPTOAuthCredentials]:
+        """Retrieve and parse OAuth credentials from api_key_enc.
+
+        Returns:
+            ChatGPTOAuthCredentials if valid credentials exist, None otherwise.
+        """
+        if not self.api_key_enc:
+            return None
+
+        json_str = await self.api_key_enc.get_plaintext_async()
+        if not json_str:
+            return None
+
+        try:
+            return ChatGPTOAuthCredentials.from_json(json_str)
+        except (json.JSONDecodeError, ValueError) as e:
+            logger.error(f"Failed to parse ChatGPT OAuth credentials: {e}")
+            return None
+
+    async def refresh_token_if_needed(
+        self, actor: Optional["User"] = None, force_refresh: bool = False
+    ) -> Optional[ChatGPTOAuthCredentials]:
+        """Check if token needs refresh and refresh if necessary.
+
+        This method is called before each API request to ensure valid credentials.
+        Tokens are refreshed 5 minutes before expiry to avoid edge cases.
+
+        Args:
+            actor: The user performing the action. Required for persisting refreshed credentials.
+            force_refresh: If True, always refresh the token regardless of expiry. For testing only.
+
+        Returns:
+            Updated credentials if successful, None on failure.
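+
+        Example (illustrative usage from a request path):
+
+            creds = await provider.refresh_token_if_needed(actor=actor)
+            if creds is None:
+                raise LLMAuthenticationError(
+                    message="Failed to obtain valid ChatGPT OAuth credentials",
+                    code=ErrorCode.UNAUTHENTICATED,
+                )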
+        """
+        creds = await self.get_oauth_credentials()
+        if not creds:
+            return None
+
+        if not creds.is_expired() and not force_refresh:
+            return creds
+
+        # Token needs refresh
+        logger.debug(f"ChatGPT OAuth token refresh triggered (expired={creds.is_expired()}, force={force_refresh})")
+
+        try:
+            new_creds = await self._perform_token_refresh(creds)
+            # Update stored credentials in memory and persist to database
+            await self._update_stored_credentials(new_creds, actor=actor)
+            return new_creds
+        except Exception as e:
+            logger.error(f"Failed to refresh ChatGPT OAuth token: {e}")
+            # If refresh fails but original access_token is still valid, use it
+            if not creds.is_expired():
+                logger.warning("Token refresh failed, but original access_token is still valid - using existing token")
+                return creds
+            # Both refresh failed AND token is expired - return None to trigger auth error
+            return None
+
+    async def _perform_token_refresh(self, creds: ChatGPTOAuthCredentials) -> ChatGPTOAuthCredentials:
+        """Perform OAuth token refresh with OpenAI's token endpoint.
+
+        Args:
+            creds: Current credentials containing the refresh_token.
+
+        Returns:
+            New ChatGPTOAuthCredentials with refreshed access_token.
+
+        Raises:
+            LLMAuthenticationError: If refresh fails due to invalid credentials.
+            LLMError: If refresh fails due to network or server error.
+        """
+        async with httpx.AsyncClient() as client:
+            try:
+                response = await client.post(
+                    CHATGPT_TOKEN_REFRESH_URL,
+                    data={
+                        "grant_type": "refresh_token",
+                        "refresh_token": creds.refresh_token,
+                        "client_id": CHATGPT_OAUTH_CLIENT_ID,
+                    },
+                    headers={
+                        "Content-Type": "application/x-www-form-urlencoded",
+                    },
+                    timeout=30.0,
+                )
+                response.raise_for_status()
+                data = response.json()
+
+                # Calculate new expiry time (Unix epoch seconds)
+                expires_in = data.get("expires_in", 3600)
+                new_expires_at = int(datetime.now().timestamp()) + expires_in
+
+                new_access_token = data["access_token"]
+                new_refresh_token = data.get("refresh_token", creds.refresh_token)
+
+                logger.debug(f"ChatGPT OAuth token refreshed, expires_in={expires_in}s")
+
+                return ChatGPTOAuthCredentials(
+                    access_token=new_access_token,
+                    refresh_token=new_refresh_token,
+                    account_id=creds.account_id,  # Account ID doesn't change
+                    expires_at=new_expires_at,
+                )
+            except httpx.HTTPStatusError as e:
+                # Log full error details for debugging
+                try:
+                    error_body = e.response.json()
+                    logger.error(f"Token refresh HTTP error: {e.response.status_code} - JSON: {error_body}")
+                except Exception:
+                    logger.error(f"Token refresh HTTP error: {e.response.status_code} - Text: {e.response.text}")
+                if e.response.status_code == 401:
+                    raise LLMAuthenticationError(
+                        message="Failed to refresh ChatGPT OAuth token: refresh token is invalid or expired",
+                        code=ErrorCode.UNAUTHENTICATED,
+                    )
+                raise LLMError(
+                    message=f"Failed to refresh ChatGPT OAuth token: {e}",
+                    code=ErrorCode.INTERNAL_SERVER_ERROR,
+                )
+            except Exception as e:
+                logger.error(f"Token refresh error: {type(e).__name__}: {e}")
+                raise LLMError(
+                    message=f"Failed to refresh ChatGPT OAuth token: {e}",
+                    code=ErrorCode.INTERNAL_SERVER_ERROR,
+                )
+
+    async def _update_stored_credentials(self, creds: ChatGPTOAuthCredentials, actor: Optional["User"] = None) -> None:
+        """Update stored credentials in memory and persist to database.
+
+        Args:
+            creds: New credentials to store.
+            actor: The user performing the action. Required for database persistence.
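+
+        Example (illustrative; persistence only runs when both actor and self.id are set):
+
+            await self._update_stored_credentials(new_creds, actor=actor)
+            # api_key_enc now holds new_creds.to_json(); the database row is
+            # updated best-effort via ProviderManager.update_provider_async.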
+ """ + new_secret = await Secret.from_plaintext_async(creds.to_json()) + # Update in-memory value + object.__setattr__(self, "api_key_enc", new_secret) + + # Persist to database if we have an actor and provider ID + if actor and self.id: + try: + from letta.schemas.providers.base import ProviderUpdate + from letta.services.provider_manager import ProviderManager + + provider_manager = ProviderManager() + await provider_manager.update_provider_async( + provider_id=self.id, + provider_update=ProviderUpdate(api_key=creds.to_json()), + actor=actor, + ) + except Exception as e: + logger.error(f"Failed to persist refreshed credentials to database: {e}") + # Don't fail the request - we have valid credentials in memory + + async def check_api_key(self): + """Validate the OAuth credentials by checking token validity. + + Raises: + ValueError: If no credentials are configured. + LLMAuthenticationError: If credentials are invalid. + """ + creds = await self.get_oauth_credentials() + if not creds: + raise ValueError("No ChatGPT OAuth credentials configured") + + # Try to refresh if needed + creds = await self.refresh_token_if_needed() + if not creds: + raise LLMAuthenticationError( + message="Failed to obtain valid ChatGPT OAuth credentials", + code=ErrorCode.UNAUTHENTICATED, + ) + + # Optionally make a test request to validate + # For now, we just verify we have valid-looking credentials + if not creds.access_token or not creds.account_id: + raise LLMAuthenticationError( + message="ChatGPT OAuth credentials are incomplete", + code=ErrorCode.UNAUTHENTICATED, + ) + + def get_default_max_output_tokens(self, model_name: str) -> int: + """Get the default max output tokens for ChatGPT models.""" + # Reasoning models (o-series) have higher limits + if model_name.startswith("o1") or model_name.startswith("o3") or model_name.startswith("o4"): + return 100000 + # GPT-5.x models + elif "gpt-5" in model_name: + return 16384 + # GPT-4 models + elif "gpt-4" in model_name: + return 16384 + return 4096 + + async def list_llm_models_async(self) -> list[LLMConfig]: + """List available models from ChatGPT backend. + + Returns a hardcoded list of models available via ChatGPT Plus/Pro subscriptions. 
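+
+        Example (illustrative; handles depend on the provider's configured name):
+
+            configs = await provider.list_llm_models_async()
+            # e.g. configs[0].model == "gpt-5.2" with
+            # model_endpoint_type == "chatgpt_oauth" and context_window == 272000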
+ """ + creds = await self.get_oauth_credentials() + if not creds: + logger.warning("Cannot list models: no valid ChatGPT OAuth credentials") + return [] + + configs = [] + for model in CHATGPT_MODELS: + model_name = model["name"] + context_window = model["context_window"] + + configs.append( + LLMConfig( + model=model_name, + model_endpoint_type="chatgpt_oauth", + model_endpoint=self.base_url, + context_window=context_window, + handle=self.get_handle(model_name), + max_tokens=self.get_default_max_output_tokens(model_name), + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + + return configs + + async def list_embedding_models_async(self) -> list: + """ChatGPT backend does not support embedding models.""" + return [] + + def get_model_context_window(self, model_name: str) -> int | None: + """Get the context window for a model.""" + for model in CHATGPT_MODELS: + if model["name"] == model_name: + return model["context_window"] + return 128000 # Default + + async def get_model_context_window_async(self, model_name: str) -> int | None: + """Get the context window for a model (async version).""" + return self.get_model_context_window(model_name) diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index f8fbcbc9..509fd088 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -1542,6 +1542,7 @@ async def send_message( "zai", "groq", "deepseek", + "chatgpt_oauth", ] # Create a new run for execution tracking @@ -2126,6 +2127,7 @@ async def preview_model_request( "zai", "groq", "deepseek", + "chatgpt_oauth", ] if agent_eligible and model_compatible: @@ -2180,6 +2182,7 @@ async def summarize_messages( "zai", "groq", "deepseek", + "chatgpt_oauth", ] if agent_eligible and model_compatible: diff --git a/letta/server/server.py b/letta/server/server.py index 82527bf2..196c35f8 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1693,7 +1693,7 @@ class SyncServer(object): # TODO: cleanup this logic llm_config = letta_agent.agent_state.llm_config # supports_token_streaming = ["openai", "anthropic", "xai", "deepseek"] - supports_token_streaming = ["openai", "anthropic", "deepseek"] # TODO re-enable xAI once streaming is patched + supports_token_streaming = ["openai", "anthropic", "deepseek", "chatgpt_oauth"] # TODO re-enable xAI once streaming is patched if stream_tokens and (llm_config.model_endpoint_type not in supports_token_streaming): logger.warning( f"Token streaming is only supported for models with type {' or '.join(supports_token_streaming)} in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False." @@ -1825,7 +1825,7 @@ class SyncServer(object): letta_multi_agent = load_multi_agent(group=group, agent_state=agent_state, actor=actor) llm_config = letta_multi_agent.agent_state.llm_config - supports_token_streaming = ["openai", "anthropic", "deepseek"] + supports_token_streaming = ["openai", "anthropic", "deepseek", "chatgpt_oauth"] if stream_tokens and (llm_config.model_endpoint_type not in supports_token_streaming): logger.warning( f"Token streaming is only supported for models with type {' or '.join(supports_token_streaming)} in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False." 
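A minimal sketch of how the pieces above fit together (illustrative only, not part of the diff; it assumes `actor`, `agent_type`, `messages`, and `tools` are already in scope, that a BYOK provider named "chatgpt" stores the OAuth credential JSON, and that `LLMClient.create` follows the factory extended in letta/llm_api/llm_client.py):

    from letta.llm_api.llm_client import LLMClient
    from letta.schemas.enums import ProviderCategory, ProviderType
    from letta.schemas.llm_config import LLMConfig

    llm_config = LLMConfig(
        model="gpt-5.1",
        model_endpoint_type="chatgpt_oauth",
        model_endpoint="https://chatgpt.com/backend-api/codex/responses",
        context_window=272000,
        provider_name="chatgpt",
        provider_category=ProviderCategory.byok,
    )

    client = LLMClient.create(provider_type=ProviderType.chatgpt_oauth, actor=actor)
    request_data = client.build_request_data(agent_type, messages, llm_config, tools=tools)
    response_data = await client.request_async(request_data, llm_config)
    chat_completion = await client.convert_response_to_chat_completion(response_data, messages, llm_config)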
diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 87fee46c..bdf5e0e2 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -482,6 +482,7 @@ class ProviderManager: try: # Get the provider class and create an instance + from letta.schemas.enums import ProviderType from letta.schemas.providers.anthropic import AnthropicProvider from letta.schemas.providers.azure import AzureProvider from letta.schemas.providers.bedrock import BedrockProvider @@ -491,42 +492,47 @@ class ProviderManager: from letta.schemas.providers.openai import OpenAIProvider from letta.schemas.providers.zai import ZAIProvider - provider_type_to_class = { - "openai": OpenAIProvider, - "anthropic": AnthropicProvider, - "groq": GroqProvider, - "google": GoogleAIProvider, - "ollama": OllamaProvider, - "bedrock": BedrockProvider, - "azure": AzureProvider, - "zai": ZAIProvider, - } + # ChatGPT OAuth requires cast_to_subtype to preserve api_key_enc and id + # (needed for OAuth token refresh and database persistence) + if provider.provider_type == ProviderType.chatgpt_oauth: + provider_instance = provider.cast_to_subtype() + else: + provider_type_to_class = { + "openai": OpenAIProvider, + "anthropic": AnthropicProvider, + "groq": GroqProvider, + "google": GoogleAIProvider, + "ollama": OllamaProvider, + "bedrock": BedrockProvider, + "azure": AzureProvider, + "zai": ZAIProvider, + } - provider_type = provider.provider_type.value if hasattr(provider.provider_type, "value") else str(provider.provider_type) - provider_class = provider_type_to_class.get(provider_type) + provider_type = provider.provider_type.value if hasattr(provider.provider_type, "value") else str(provider.provider_type) + provider_class = provider_type_to_class.get(provider_type) - if not provider_class: - logger.warning(f"No provider class found for type '{provider_type}'") - return + if not provider_class: + logger.warning(f"No provider class found for type '{provider_type}'") + return - # Create provider instance with necessary parameters - api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None - access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None - kwargs = { - "name": provider.name, - "api_key": api_key, - "provider_category": provider.provider_category, - } - if provider.base_url: - kwargs["base_url"] = provider.base_url - if access_key: - kwargs["access_key"] = access_key - if provider.region: - kwargs["region"] = provider.region - if provider.api_version: - kwargs["api_version"] = provider.api_version + # Create provider instance with necessary parameters + api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None + access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None + kwargs = { + "name": provider.name, + "api_key": api_key, + "provider_category": provider.provider_category, + } + if provider.base_url: + kwargs["base_url"] = provider.base_url + if access_key: + kwargs["access_key"] = access_key + if provider.region: + kwargs["region"] = provider.region + if provider.api_version: + kwargs["api_version"] = provider.api_version - provider_instance = provider_class(**kwargs) + provider_instance = provider_class(**kwargs) # Query the provider's API for available models llm_models = await provider_instance.list_llm_models_async() diff --git a/letta/services/streaming_service.py b/letta/services/streaming_service.py 
index 79d46673..313b0d5f 100644 --- a/letta/services/streaming_service.py +++ b/letta/services/streaming_service.py @@ -493,11 +493,12 @@ class StreamingService: "zai", "groq", "deepseek", + "chatgpt_oauth", ] def _is_token_streaming_compatible(self, agent: AgentState) -> bool: """Check if agent's model supports token-level streaming.""" - base_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock", "deepseek", "zai"] + base_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock", "deepseek", "zai", "chatgpt_oauth"] google_letta_v1 = agent.agent_type == AgentType.letta_v1_agent and agent.llm_config.model_endpoint_type in [ "google_ai", "google_vertex",