diff --git a/fern/openapi.json b/fern/openapi.json index 0f50834d..3478fed9 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -29192,43 +29192,6 @@ "title": "ChatCompletionSystemMessageParam", "description": "Developer-provided instructions that the model should follow, regardless of\nmessages sent by the user. With o1 models and newer, use `developer` messages\nfor this purpose instead." }, - "ChatCompletionTokenLogprob": { - "properties": { - "token": { - "type": "string", - "title": "Token" - }, - "bytes": { - "anyOf": [ - { - "items": { - "type": "integer" - }, - "type": "array" - }, - { - "type": "null" - } - ], - "title": "Bytes" - }, - "logprob": { - "type": "number", - "title": "Logprob" - }, - "top_logprobs": { - "items": { - "$ref": "#/components/schemas/TopLogprob" - }, - "type": "array", - "title": "Top Logprobs" - } - }, - "additionalProperties": true, - "type": "object", - "required": ["token", "logprob", "top_logprobs"], - "title": "ChatCompletionTokenLogprob" - }, "ChatCompletionToolMessageParam": { "properties": { "content": { @@ -29453,7 +29416,7 @@ "logprobs": { "anyOf": [ { - "$ref": "#/components/schemas/ChoiceLogprobs" + "$ref": "#/components/schemas/openai__types__chat__chat_completion__ChoiceLogprobs" }, { "type": "null" @@ -29469,42 +29432,6 @@ "required": ["finish_reason", "index", "message"], "title": "Choice" }, - "ChoiceLogprobs": { - "properties": { - "content": { - "anyOf": [ - { - "items": { - "$ref": "#/components/schemas/ChatCompletionTokenLogprob" - }, - "type": "array" - }, - { - "type": "null" - } - ], - "title": "Content" - }, - "refusal": { - "anyOf": [ - { - "items": { - "$ref": "#/components/schemas/ChatCompletionTokenLogprob" - }, - "type": "array" - }, - { - "type": "null" - } - ], - "title": "Refusal" - } - }, - "additionalProperties": true, - "type": "object", - "title": "ChoiceLogprobs", - "description": "Log probability information for the choice." - }, "ClientToolSchema": { "properties": { "name": { @@ -30525,6 +30452,30 @@ "description": "If True, compaction events emit structured `SummaryMessage` and `EventMessage` types. If False (default), compaction messages are not included in the response.", "default": false }, + "return_logprobs": { + "type": "boolean", + "title": "Return Logprobs", + "description": "If True, returns log probabilities of the output tokens in the response. Useful for RL training. Only supported for OpenAI-compatible providers (including SGLang).", + "default": false + }, + "top_logprobs": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Top Logprobs", + "description": "Number of most likely tokens to return at each position (0-20). Requires return_logprobs=True." + }, + "return_token_ids": { + "type": "boolean", + "title": "Return Token Ids", + "description": "If True, returns token IDs and logprobs for ALL LLM generations in the agent step, not just the last one. Uses SGLang native /generate endpoint. Returns 'turns' field with TurnTokenData for each assistant/tool turn. Required for proper multi-turn RL training with loss masking.", + "default": false + }, "streaming": { "type": "boolean", "title": "Streaming", @@ -36716,6 +36667,30 @@ "title": "Strict", "description": "Enable strict mode for tool calling. When true, tool schemas include strict: true and additionalProperties: false, guaranteeing tool outputs match JSON schemas.", "default": false + }, + "return_logprobs": { + "type": "boolean", + "title": "Return Logprobs", + "description": "Whether to return log probabilities of the output tokens. Useful for RL training.", + "default": false + }, + "top_logprobs": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Top Logprobs", + "description": "Number of most likely tokens to return at each position (0-20). Requires return_logprobs=True." + }, + "return_token_ids": { + "type": "boolean", + "title": "Return Token Ids", + "description": "Whether to return token IDs for all LLM generations via SGLang native endpoint. Required for multi-turn RL training with loss masking. Only works with SGLang provider.", + "default": false } }, "type": "object", @@ -36888,6 +36863,30 @@ "description": "If True, compaction events emit structured `SummaryMessage` and `EventMessage` types. If False (default), compaction messages are not included in the response.", "default": false }, + "return_logprobs": { + "type": "boolean", + "title": "Return Logprobs", + "description": "If True, returns log probabilities of the output tokens in the response. Useful for RL training. Only supported for OpenAI-compatible providers (including SGLang).", + "default": false + }, + "top_logprobs": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Top Logprobs", + "description": "Number of most likely tokens to return at each position (0-20). Requires return_logprobs=True." + }, + "return_token_ids": { + "type": "boolean", + "title": "Return Token Ids", + "description": "If True, returns token IDs and logprobs for ALL LLM generations in the agent step, not just the last one. Uses SGLang native /generate endpoint. Returns 'turns' field with TurnTokenData for each assistant/tool turn. Required for proper multi-turn RL training with loss masking.", + "default": false + }, "callback_url": { "anyOf": [ { @@ -37083,6 +37082,30 @@ "description": "If True, compaction events emit structured `SummaryMessage` and `EventMessage` types. If False (default), compaction messages are not included in the response.", "default": false }, + "return_logprobs": { + "type": "boolean", + "title": "Return Logprobs", + "description": "If True, returns log probabilities of the output tokens in the response. Useful for RL training. Only supported for OpenAI-compatible providers (including SGLang).", + "default": false + }, + "top_logprobs": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Top Logprobs", + "description": "Number of most likely tokens to return at each position (0-20). Requires return_logprobs=True." + }, + "return_token_ids": { + "type": "boolean", + "title": "Return Token Ids", + "description": "If True, returns token IDs and logprobs for ALL LLM generations in the agent step, not just the last one. Uses SGLang native /generate endpoint. Returns 'turns' field with TurnTokenData for each assistant/tool turn. Required for proper multi-turn RL training with loss masking.", + "default": false + }, "agent_id": { "type": "string", "maxLength": 42, @@ -37457,6 +37480,30 @@ "title": "Include Compaction Messages", "description": "If True, compaction events emit structured `SummaryMessage` and `EventMessage` types. If False (default), compaction messages are not included in the response.", "default": false + }, + "return_logprobs": { + "type": "boolean", + "title": "Return Logprobs", + "description": "If True, returns log probabilities of the output tokens in the response. Useful for RL training. Only supported for OpenAI-compatible providers (including SGLang).", + "default": false + }, + "top_logprobs": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Top Logprobs", + "description": "Number of most likely tokens to return at each position (0-20). Requires return_logprobs=True." + }, + "return_token_ids": { + "type": "boolean", + "title": "Return Token Ids", + "description": "If True, returns token IDs and logprobs for ALL LLM generations in the agent step, not just the last one. Uses SGLang native /generate endpoint. Returns 'turns' field with TurnTokenData for each assistant/tool turn. Required for proper multi-turn RL training with loss masking.", + "default": false } }, "type": "object", @@ -37517,6 +37564,32 @@ "usage": { "$ref": "#/components/schemas/LettaUsageStatistics", "description": "The usage statistics of the agent." + }, + "logprobs": { + "anyOf": [ + { + "$ref": "#/components/schemas/letta__schemas__openai__chat_completion_response__ChoiceLogprobs" + }, + { + "type": "null" + } + ], + "description": "Log probabilities of the output tokens from the last LLM call. Only present if return_logprobs was enabled." + }, + "turns": { + "anyOf": [ + { + "items": { + "$ref": "#/components/schemas/TurnTokenData" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Turns", + "description": "Token data for all LLM generations in multi-turn agent interaction. Includes token IDs and logprobs for each assistant turn, plus tool result content. Only present if return_token_ids was enabled. Used for RL training with loss masking." } }, "type": "object", @@ -37708,6 +37781,30 @@ "description": "If True, compaction events emit structured `SummaryMessage` and `EventMessage` types. If False (default), compaction messages are not included in the response.", "default": false }, + "return_logprobs": { + "type": "boolean", + "title": "Return Logprobs", + "description": "If True, returns log probabilities of the output tokens in the response. Useful for RL training. Only supported for OpenAI-compatible providers (including SGLang).", + "default": false + }, + "top_logprobs": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Top Logprobs", + "description": "Number of most likely tokens to return at each position (0-20). Requires return_logprobs=True." + }, + "return_token_ids": { + "type": "boolean", + "title": "Return Token Ids", + "description": "If True, returns token IDs and logprobs for ALL LLM generations in the agent step, not just the last one. Uses SGLang native /generate endpoint. Returns 'turns' field with TurnTokenData for each assistant/tool turn. Required for proper multi-turn RL training with loss masking.", + "default": false + }, "streaming": { "type": "boolean", "title": "Streaming", @@ -39265,6 +39362,30 @@ "description": "Enable strict mode for tool calling. When true, tool schemas include strict: true and additionalProperties: false, guaranteeing tool outputs match JSON schemas.", "default": false }, + "return_logprobs": { + "type": "boolean", + "title": "Return Logprobs", + "description": "Whether to return log probabilities of the output tokens. Useful for RL training.", + "default": false + }, + "top_logprobs": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Top Logprobs", + "description": "Number of most likely tokens to return at each position (0-20). Requires return_logprobs=True." + }, + "return_token_ids": { + "type": "boolean", + "title": "Return Token Ids", + "description": "Whether to return token IDs for all LLM generations via SGLang native endpoint. Required for multi-turn RL training with loss masking. Only works with SGLang provider.", + "default": false + }, "max_context_window": { "type": "integer", "title": "Max Context Window", @@ -45313,13 +45434,15 @@ "type": "object", "title": "ToolUpdate" }, - "TopLogprob": { + "TurnTokenData": { "properties": { - "token": { + "role": { "type": "string", - "title": "Token" + "enum": ["assistant", "tool"], + "title": "Role", + "description": "Role of this turn: 'assistant' for LLM generations (trainable), 'tool' for tool results (non-trainable)." }, - "bytes": { + "output_ids": { "anyOf": [ { "items": { @@ -45331,17 +45454,54 @@ "type": "null" } ], - "title": "Bytes" + "title": "Output Ids", + "description": "Token IDs from SGLang native endpoint. Only present for assistant turns." }, - "logprob": { - "type": "number", - "title": "Logprob" + "output_token_logprobs": { + "anyOf": [ + { + "items": { + "items": {}, + "type": "array" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Output Token Logprobs", + "description": "Logprobs from SGLang: [[logprob, token_id, top_logprob_or_null], ...]. Only present for assistant turns." + }, + "content": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Content", + "description": "Text content. For tool turns, client tokenizes this with loss_mask=0." + }, + "tool_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Tool Name", + "description": "Name of the tool called. Only present for tool turns." } }, - "additionalProperties": true, "type": "object", - "required": ["token", "logprob"], - "title": "TopLogprob" + "required": ["role"], + "title": "TurnTokenData", + "description": "Token data for a single LLM generation turn in a multi-turn agent interaction.\n\nUsed for RL training to track token IDs and logprobs across all LLM calls,\nnot just the final one. Tool results are included so the client can tokenize\nthem with loss_mask=0 (non-trainable)." }, "UpdateAgent": { "properties": { @@ -48653,6 +48813,105 @@ "required": ["status"], "title": "ToolReturn" }, + "letta__schemas__openai__chat_completion_response__ChatCompletionTokenLogprob": { + "properties": { + "token": { + "type": "string", + "title": "Token" + }, + "bytes": { + "anyOf": [ + { + "items": { + "type": "integer" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Bytes" + }, + "logprob": { + "type": "number", + "title": "Logprob" + }, + "top_logprobs": { + "items": { + "$ref": "#/components/schemas/letta__schemas__openai__chat_completion_response__TopLogprob" + }, + "type": "array", + "title": "Top Logprobs" + } + }, + "type": "object", + "required": ["token", "logprob", "top_logprobs"], + "title": "ChatCompletionTokenLogprob" + }, + "letta__schemas__openai__chat_completion_response__ChoiceLogprobs": { + "properties": { + "content": { + "anyOf": [ + { + "items": { + "$ref": "#/components/schemas/letta__schemas__openai__chat_completion_response__ChatCompletionTokenLogprob" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Content" + }, + "refusal": { + "anyOf": [ + { + "items": { + "$ref": "#/components/schemas/letta__schemas__openai__chat_completion_response__ChatCompletionTokenLogprob" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Refusal" + } + }, + "type": "object", + "title": "ChoiceLogprobs" + }, + "letta__schemas__openai__chat_completion_response__TopLogprob": { + "properties": { + "token": { + "type": "string", + "title": "Token" + }, + "bytes": { + "anyOf": [ + { + "items": { + "type": "integer" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Bytes" + }, + "logprob": { + "type": "number", + "title": "Logprob" + } + }, + "type": "object", + "required": ["token", "logprob"], + "title": "TopLogprob" + }, "letta__serialize_schemas__pydantic_agent_schema__AgentSchema": { "properties": { "agent_type": { @@ -48999,6 +49258,42 @@ "type": "object", "title": "ToolExecuteRequest" }, + "openai__types__chat__chat_completion__ChoiceLogprobs": { + "properties": { + "content": { + "anyOf": [ + { + "items": { + "$ref": "#/components/schemas/openai__types__chat__chat_completion_token_logprob__ChatCompletionTokenLogprob" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Content" + }, + "refusal": { + "anyOf": [ + { + "items": { + "$ref": "#/components/schemas/openai__types__chat__chat_completion_token_logprob__ChatCompletionTokenLogprob" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Refusal" + } + }, + "additionalProperties": true, + "type": "object", + "title": "ChoiceLogprobs", + "description": "Log probability information for the choice." + }, "openai__types__chat__chat_completion_message_function_tool_call__Function": { "properties": { "arguments": { @@ -49032,6 +49327,73 @@ "title": "Function", "description": "The function that the model called." }, + "openai__types__chat__chat_completion_token_logprob__ChatCompletionTokenLogprob": { + "properties": { + "token": { + "type": "string", + "title": "Token" + }, + "bytes": { + "anyOf": [ + { + "items": { + "type": "integer" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Bytes" + }, + "logprob": { + "type": "number", + "title": "Logprob" + }, + "top_logprobs": { + "items": { + "$ref": "#/components/schemas/openai__types__chat__chat_completion_token_logprob__TopLogprob" + }, + "type": "array", + "title": "Top Logprobs" + } + }, + "additionalProperties": true, + "type": "object", + "required": ["token", "logprob", "top_logprobs"], + "title": "ChatCompletionTokenLogprob" + }, + "openai__types__chat__chat_completion_token_logprob__TopLogprob": { + "properties": { + "token": { + "type": "string", + "title": "Token" + }, + "bytes": { + "anyOf": [ + { + "items": { + "type": "integer" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Bytes" + }, + "logprob": { + "type": "number", + "title": "Logprob" + } + }, + "additionalProperties": true, + "type": "object", + "required": ["token", "logprob"], + "title": "TopLogprob" + }, "LettaMessageUnion": { "oneOf": [ { diff --git a/letta/adapters/letta_llm_adapter.py b/letta/adapters/letta_llm_adapter.py index 2f21862d..ba14b3d4 100644 --- a/letta/adapters/letta_llm_adapter.py +++ b/letta/adapters/letta_llm_adapter.py @@ -1,12 +1,12 @@ from abc import ABC, abstractmethod -from typing import AsyncGenerator +from typing import AsyncGenerator, Optional from letta.llm_api.llm_client_base import LLMClientBase from letta.schemas.enums import LLMCallType from letta.schemas.letta_message import LettaMessage from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.llm_config import LLMConfig -from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, ToolCall +from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, ChoiceLogprobs, ToolCall from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User from letta.services.telemetry_manager import TelemetryManager @@ -48,6 +48,10 @@ class LettaLLMAdapter(ABC): self.content: list[TextContent | ReasoningContent | RedactedReasoningContent] | None = None self.tool_call: ToolCall | None = None self.tool_calls: list[ToolCall] = [] + self.logprobs: ChoiceLogprobs | None = None + # SGLang native endpoint data (for multi-turn RL training) + self.output_ids: list[int] | None = None + self.output_token_logprobs: list[list[float]] | None = None self.usage: LettaUsageStatistics = LettaUsageStatistics() self.telemetry_manager: TelemetryManager = TelemetryManager() self.llm_request_finish_timestamp_ns: int | None = None diff --git a/letta/adapters/letta_llm_request_adapter.py b/letta/adapters/letta_llm_request_adapter.py index 8ea95680..21bc543d 100644 --- a/letta/adapters/letta_llm_request_adapter.py +++ b/letta/adapters/letta_llm_request_adapter.py @@ -83,6 +83,9 @@ class LettaLLMRequestAdapter(LettaLLMAdapter): else: self.tool_call = None + # Extract logprobs if present + self.logprobs = self.chat_completions_response.choices[0].logprobs + # Extract usage statistics self.usage.step_count = 1 self.usage.completion_tokens = self.chat_completions_response.usage.completion_tokens diff --git a/letta/adapters/sglang_native_adapter.py b/letta/adapters/sglang_native_adapter.py new file mode 100644 index 00000000..fbaa07e0 --- /dev/null +++ b/letta/adapters/sglang_native_adapter.py @@ -0,0 +1,521 @@ +""" +SGLang Native Adapter for multi-turn RL training. + +This adapter uses SGLang's native /generate endpoint instead of the OpenAI-compatible +endpoint to get token IDs and per-token logprobs, which are essential for proper +multi-turn RL training with loss masking. + +Uses HuggingFace tokenizer's apply_chat_template() for proper tool formatting. +""" + +import json +import re +import time +import uuid +from typing import Any, AsyncGenerator, Optional + +from letta.adapters.simple_llm_request_adapter import SimpleLLMRequestAdapter +from letta.helpers.datetime_helpers import get_utc_timestamp_ns +from letta.llm_api.sglang_native_client import SGLangNativeClient +from letta.log import get_logger +from letta.schemas.letta_message import LettaMessage +from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, TextContent +from letta.schemas.openai.chat_completion_response import ( + ChatCompletionResponse, + Choice, + ChoiceLogprobs, + ChatCompletionTokenLogprob, + FunctionCall, + Message as ChoiceMessage, + ToolCall, + UsageStatistics, +) +from letta.schemas.usage import normalize_cache_tokens, normalize_reasoning_tokens + +logger = get_logger(__name__) + +# Global tokenizer cache +_tokenizer_cache: dict[str, Any] = {} + + +class SGLangNativeAdapter(SimpleLLMRequestAdapter): + """ + Adapter that uses SGLang's native /generate endpoint for multi-turn RL training. + + Key differences from SimpleLLMRequestAdapter: + - Uses /generate instead of /v1/chat/completions + - Returns output_ids (token IDs) in addition to text + - Returns output_token_logprobs with [logprob, token_id] pairs + - Formats tools into prompt and parses tool calls from response + + These are essential for building accurate loss masks in multi-turn training. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._sglang_client: Optional[SGLangNativeClient] = None + self._tokenizer: Any = None + + def _get_tokenizer(self) -> Any: + """Get or create tokenizer for the model.""" + global _tokenizer_cache + + # Get model name from llm_config + model_name = self.llm_config.model + if not model_name: + logger.warning("No model name in llm_config, cannot load tokenizer") + return None + + # Check cache + if model_name in _tokenizer_cache: + return _tokenizer_cache[model_name] + + try: + from transformers import AutoTokenizer + logger.info(f"Loading tokenizer for model: {model_name}") + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + _tokenizer_cache[model_name] = tokenizer + return tokenizer + except ImportError: + logger.warning("transformers not installed, falling back to manual formatting") + return None + except Exception as e: + logger.warning(f"Failed to load tokenizer: {e}, falling back to manual formatting") + return None + + def _get_sglang_client(self) -> SGLangNativeClient: + """Get or create SGLang native client.""" + if self._sglang_client is None: + # Get base URL from llm_config, removing /v1 suffix if present + base_url = self.llm_config.model_endpoint or "" + # SGLang local instances typically don't need API key + self._sglang_client = SGLangNativeClient( + base_url=base_url, + api_key=None, + ) + return self._sglang_client + + def _format_tools_for_prompt(self, tools: list) -> str: + """ + Format tools in Qwen3 chat template format for the system prompt. + + This matches the exact format produced by Qwen3's tokenizer.apply_chat_template() + with tools parameter. + """ + if not tools: + return "" + + # Format each tool as JSON (matching Qwen3 template exactly) + tool_jsons = [] + for tool in tools: + # Handle both dict and object formats + if isinstance(tool, dict): + # Already in OpenAI format + tool_jsons.append(json.dumps(tool)) + else: + # Convert object to dict + tool_dict = { + "type": "function", + "function": { + "name": getattr(getattr(tool, "function", tool), "name", ""), + "description": getattr(getattr(tool, "function", tool), "description", ""), + "parameters": getattr(getattr(tool, "function", tool), "parameters", {}), + } + } + tool_jsons.append(json.dumps(tool_dict)) + + # Use exact Qwen3 format + tools_section = ( + "\n\n# Tools\n\n" + "You may call one or more functions to assist with the user query.\n\n" + "You are provided with function signatures within XML tags:\n" + "\n" + + "\n".join(tool_jsons) + "\n" + "\n\n" + "For each function call, return a json object with function name and arguments within XML tags:\n" + "\n" + '{"name": , "arguments": }\n' + "" + ) + + return tools_section + + def _convert_messages_to_openai_format(self, messages: list) -> list[dict]: + """Convert Letta Message objects to OpenAI-style message dicts.""" + openai_messages = [] + + for msg in messages: + # Handle both dict and Pydantic Message objects + if hasattr(msg, 'role'): + role = msg.role + content = msg.content if hasattr(msg, 'content') else "" + # Handle content that might be a list of content parts + if isinstance(content, list): + content = " ".join([c.text if hasattr(c, 'text') else str(c) for c in content]) + elif content is None: + content = "" + tool_calls = getattr(msg, 'tool_calls', None) + tool_call_id = getattr(msg, 'tool_call_id', None) + name = getattr(msg, 'name', None) + else: + role = msg.get("role", "user") + content = msg.get("content", "") + tool_calls = msg.get("tool_calls", None) + tool_call_id = msg.get("tool_call_id", None) + name = msg.get("name", None) + + openai_msg = {"role": role, "content": content} + + if tool_calls: + # Convert tool calls to OpenAI format + openai_tool_calls = [] + for tc in tool_calls: + if hasattr(tc, 'function'): + tc_dict = { + "id": getattr(tc, 'id', f"call_{uuid.uuid4().hex[:8]}"), + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments if isinstance(tc.function.arguments, str) else json.dumps(tc.function.arguments) + } + } + else: + tc_dict = { + "id": tc.get("id", f"call_{uuid.uuid4().hex[:8]}"), + "type": "function", + "function": tc.get("function", {}) + } + openai_tool_calls.append(tc_dict) + openai_msg["tool_calls"] = openai_tool_calls + + if tool_call_id: + openai_msg["tool_call_id"] = tool_call_id + + if name and role == "tool": + openai_msg["name"] = name + + openai_messages.append(openai_msg) + + return openai_messages + + def _convert_tools_to_openai_format(self, tools: list) -> list[dict]: + """Convert tools to OpenAI format for tokenizer.""" + openai_tools = [] + for tool in tools: + if isinstance(tool, dict): + # Already a dict, ensure it's in the right format + if "function" in tool: + openai_tools.append(tool) + else: + # Might be the function directly + openai_tools.append({"type": "function", "function": tool}) + else: + # Convert object to dict + func = getattr(tool, "function", tool) + tool_dict = { + "type": "function", + "function": { + "name": getattr(func, "name", ""), + "description": getattr(func, "description", ""), + "parameters": getattr(func, "parameters", {}), + } + } + openai_tools.append(tool_dict) + return openai_tools + + def _format_messages_to_text(self, messages: list, tools: list) -> str: + """ + Format messages to text using tokenizer's apply_chat_template if available. + + Falls back to manual formatting if tokenizer is not available. + """ + tokenizer = self._get_tokenizer() + + if tokenizer is not None: + # Use tokenizer's apply_chat_template for proper formatting + openai_messages = self._convert_messages_to_openai_format(messages) + openai_tools = self._convert_tools_to_openai_format(tools) if tools else None + + try: + formatted = tokenizer.apply_chat_template( + openai_messages, + tokenize=False, + add_generation_prompt=True, + tools=openai_tools, + ) + logger.debug(f"Formatted prompt using tokenizer ({len(formatted)} chars)") + return formatted + except Exception as e: + logger.warning(f"apply_chat_template failed: {e}, falling back to manual formatting") + + # Fallback to manual formatting + return self._format_messages_to_text_manual(messages, tools) + + def _format_messages_to_text_manual(self, messages: list, tools: list) -> str: + """Manual fallback formatting for when tokenizer is not available.""" + formatted_parts = [] + tools_section = self._format_tools_for_prompt(tools) + + for msg in messages: + # Handle both dict and Pydantic Message objects + if hasattr(msg, 'role'): + role = msg.role + content = msg.content if hasattr(msg, 'content') else "" + if isinstance(content, list): + content = " ".join([c.text if hasattr(c, 'text') else str(c) for c in content]) + elif content is None: + content = "" + tool_calls = getattr(msg, 'tool_calls', None) + else: + role = msg.get("role", "user") + content = msg.get("content", "") + tool_calls = msg.get("tool_calls", None) + + if role == "system": + system_content = content + tools_section if tools_section else content + formatted_parts.append(f"<|im_start|>system\n{system_content}<|im_end|>") + tools_section = "" + elif role == "user": + formatted_parts.append(f"<|im_start|>user\n{content}<|im_end|>") + elif role == "assistant": + if tool_calls: + tc_parts = [] + for tc in tool_calls: + if hasattr(tc, 'function'): + tc_name = tc.function.name + tc_args = tc.function.arguments + else: + tc_name = tc.get("function", {}).get("name", "") + tc_args = tc.get("function", {}).get("arguments", "{}") + + if isinstance(tc_args, str): + try: + tc_args = json.loads(tc_args) + except: + pass + + tc_parts.append( + f"\n" + f'{{"name": "{tc_name}", "arguments": {json.dumps(tc_args)}}}\n' + f"" + ) + + assistant_content = content + "\n" + "\n".join(tc_parts) if content else "\n".join(tc_parts) + formatted_parts.append(f"<|im_start|>assistant\n{assistant_content}<|im_end|>") + elif content: + formatted_parts.append(f"<|im_start|>assistant\n{content}<|im_end|>") + elif role == "tool": + formatted_parts.append( + f"<|im_start|>user\n" + f"\n{content}\n<|im_end|>" + ) + + formatted_parts.append("<|im_start|>assistant\n") + return "\n".join(formatted_parts) + + def _parse_tool_calls(self, text: str) -> list[ToolCall]: + """ + Parse tool calls from response text. + + Looks for patterns like: + + {"name": "tool_name", "arguments": {...}} + + """ + tool_calls = [] + + # Find all tool_call blocks + pattern = r'\s*(\{.*?\})\s*' + matches = re.findall(pattern, text, re.DOTALL) + + for match in matches: + try: + tc_data = json.loads(match) + name = tc_data.get("name", "") + arguments = tc_data.get("arguments", {}) + + if isinstance(arguments, dict): + arguments = json.dumps(arguments) + + tool_call = ToolCall( + id=f"call_{uuid.uuid4().hex[:8]}", + type="function", + function=FunctionCall( + name=name, + arguments=arguments, + ), + ) + tool_calls.append(tool_call) + except json.JSONDecodeError as e: + logger.warning(f"Failed to parse tool call JSON: {e}") + continue + + return tool_calls + + def _extract_content_without_tool_calls(self, text: str) -> str: + """Extract content from response, removing tool_call blocks.""" + # Remove tool_call blocks + cleaned = re.sub(r'.*?', '', text, flags=re.DOTALL) + # Clean up whitespace + cleaned = cleaned.strip() + return cleaned + + async def invoke_llm( + self, + request_data: dict, + messages: list, + tools: list, + use_assistant_message: bool, + requires_approval_tools: list[str] = [], + step_id: str | None = None, + actor: str | None = None, + ) -> AsyncGenerator[LettaMessage | None, None]: + """ + Execute LLM request using SGLang native endpoint. + + This method: + 1. Formats messages and tools to text using chat template + 2. Calls SGLang native /generate endpoint + 3. Extracts output_ids and output_token_logprobs + 4. Parses tool calls from response + 5. Converts response to standard format + """ + self.request_data = request_data + + # Get sampling params from request_data + sampling_params = { + "temperature": request_data.get("temperature", 0.7), + "max_new_tokens": request_data.get("max_tokens", 4096), + "top_p": request_data.get("top_p", 0.9), + } + + # Format messages to text (includes tools in prompt) + text_input = self._format_messages_to_text(messages, tools) + + # Call SGLang native endpoint + client = self._get_sglang_client() + + try: + response = await client.generate( + text=text_input, + sampling_params=sampling_params, + return_logprob=True, + ) + except Exception as e: + logger.error(f"SGLang native endpoint error: {e}") + raise + + self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns() + + # Store native response data + self.response_data = response + + # Extract SGLang native data + self.output_ids = response.get("output_ids") + # output_token_logprobs is inside meta_info + meta_info = response.get("meta_info", {}) + self.output_token_logprobs = meta_info.get("output_token_logprobs") + + # Extract text response + text_response = response.get("text", "") + + # Remove trailing end token if present + if text_response.endswith("<|im_end|>"): + text_response = text_response[:-10] + + # Parse tool calls from response + parsed_tool_calls = self._parse_tool_calls(text_response) + + # Extract content (text without tool_call blocks) + content_text = self._extract_content_without_tool_calls(text_response) + + # Determine finish reason + meta_info = response.get("meta_info", {}) + finish_reason_info = meta_info.get("finish_reason", {}) + if isinstance(finish_reason_info, dict): + finish_reason = finish_reason_info.get("type", "stop") + else: + finish_reason = "stop" + + # If we have tool calls, set finish_reason to tool_calls + if parsed_tool_calls: + finish_reason = "tool_calls" + + # Convert to standard ChatCompletionResponse format for compatibility + # Build logprobs in OpenAI format from SGLang format + logprobs_content = None + if self.output_token_logprobs: + logprobs_content = [] + for i, lp_data in enumerate(self.output_token_logprobs): + # SGLang format: [logprob, token_id, top_logprob] + logprob = lp_data[0] if len(lp_data) > 0 else 0.0 + token_id = lp_data[1] if len(lp_data) > 1 else 0 + logprobs_content.append( + ChatCompletionTokenLogprob( + token=str(token_id), + logprob=logprob, + bytes=None, + top_logprobs=[], + ) + ) + + choice_logprobs = ChoiceLogprobs(content=logprobs_content) if logprobs_content else None + + # Build chat completion response + prompt_tokens = meta_info.get("prompt_tokens", 0) + completion_tokens = len(self.output_ids) if self.output_ids else 0 + + self.chat_completions_response = ChatCompletionResponse( + id=meta_info.get("id", "sglang-native"), + created=int(time.time()), + choices=[ + Choice( + finish_reason=finish_reason, + index=0, + message=ChoiceMessage( + role="assistant", + content=content_text if content_text else None, + tool_calls=parsed_tool_calls if parsed_tool_calls else None, + ), + logprobs=choice_logprobs, + ) + ], + usage=UsageStatistics( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + + # Extract content + if content_text: + self.content = [TextContent(text=content_text)] + else: + self.content = None + + # No reasoning content from native endpoint + self.reasoning_content = None + + # Set tool calls + self.tool_calls = parsed_tool_calls + self.tool_call = parsed_tool_calls[0] if parsed_tool_calls else None + + # Set logprobs + self.logprobs = choice_logprobs + + # Extract usage statistics + self.usage.step_count = 1 + self.usage.completion_tokens = completion_tokens + self.usage.prompt_tokens = prompt_tokens + self.usage.total_tokens = prompt_tokens + completion_tokens + + self.log_provider_trace(step_id=step_id, actor=actor) + + logger.info( + f"SGLang native response: {len(self.output_ids or [])} tokens, " + f"{len(self.output_token_logprobs or [])} logprobs, " + f"{len(parsed_tool_calls)} tool calls" + ) + + yield None + return diff --git a/letta/adapters/simple_llm_request_adapter.py b/letta/adapters/simple_llm_request_adapter.py index 7cec9472..8ab2a904 100644 --- a/letta/adapters/simple_llm_request_adapter.py +++ b/letta/adapters/simple_llm_request_adapter.py @@ -99,6 +99,9 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter): self.tool_calls = list(tool_calls) self.tool_call = self.tool_calls[0] if self.tool_calls else None + # Extract logprobs if present + self.logprobs = self.chat_completions_response.choices[0].logprobs + # Extract usage statistics self.usage.step_count = 1 self.usage.completion_tokens = self.chat_completions_response.usage.completion_tokens diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index bd111eee..7c8942c9 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -6,6 +6,7 @@ from typing import Any, AsyncGenerator, Dict, Literal, Optional from opentelemetry.trace import Span from letta.adapters.letta_llm_adapter import LettaLLMAdapter +from letta.adapters.sglang_native_adapter import SGLangNativeAdapter from letta.adapters.simple_llm_request_adapter import SimpleLLMRequestAdapter from letta.adapters.simple_llm_stream_adapter import SimpleLLMStreamAdapter from letta.agents.helpers import ( @@ -41,11 +42,11 @@ from letta.schemas.letta_message import ( ) from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.letta_request import ClientToolSchema -from letta.schemas.letta_response import LettaResponse +from letta.schemas.letta_response import LettaResponse, TurnTokenData from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message, MessageCreate, ToolReturn -from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall, ToolCallDenial, UsageStatistics +from letta.schemas.openai.chat_completion_response import ChoiceLogprobs, FunctionCall, ToolCall, ToolCallDenial, UsageStatistics from letta.schemas.step import StepProgression from letta.schemas.step_metrics import StepMetrics from letta.schemas.tool_execution_result import ToolExecutionResult @@ -121,6 +122,11 @@ class LettaAgentV3(LettaAgentV2): self.conversation_id: str | None = None # Client-side tools passed in the request (executed by client, not server) self.client_tools: list[ClientToolSchema] = [] + # Log probabilities from the most recent LLM call (for RL training) + self.logprobs: ChoiceLogprobs | None = None + # Multi-turn token tracking for RL training (accumulated across all LLM calls) + self.turns: list[TurnTokenData] = [] + self.return_token_ids: bool = False def _compute_tool_return_truncation_chars(self) -> int: """Compute a dynamic cap for tool returns in requests. @@ -196,6 +202,41 @@ class LettaAgentV3(LettaAgentV2): input_messages_to_persist = [input_messages_to_persist[0]] self.in_context_messages = curr_in_context_messages + + # Check if we should use SGLang native adapter for multi-turn RL training + use_sglang_native = ( + self.agent_state.llm_config.return_token_ids + and self.agent_state.llm_config.handle + and self.agent_state.llm_config.handle.startswith("sglang/") + ) + self.return_token_ids = use_sglang_native + + if use_sglang_native: + # Use SGLang native adapter for multi-turn RL training + llm_adapter = SGLangNativeAdapter( + llm_client=self.llm_client, + llm_config=self.agent_state.llm_config, + call_type=LLMCallType.agent_step, + agent_id=self.agent_state.id, + agent_tags=self.agent_state.tags, + run_id=run_id, + org_id=self.actor.organization_id, + user_id=self.actor.id, + ) + # Reset turns tracking for this step + self.turns = [] + else: + llm_adapter = SimpleLLMRequestAdapter( + llm_client=self.llm_client, + llm_config=self.agent_state.llm_config, + call_type=LLMCallType.agent_step, + agent_id=self.agent_state.id, + agent_tags=self.agent_state.tags, + run_id=run_id, + org_id=self.actor.organization_id, + user_id=self.actor.id, + ) + for i in range(max_steps): if i == 1 and follow_up_messages: input_messages_to_persist = follow_up_messages @@ -205,17 +246,7 @@ class LettaAgentV3(LettaAgentV2): # we append input_messages_to_persist since they aren't checkpointed as in-context until the end of the step (may be rolled back) messages=list(self.in_context_messages + input_messages_to_persist), input_messages_to_persist=input_messages_to_persist, - # TODO need to support non-streaming adapter too - llm_adapter=SimpleLLMRequestAdapter( - llm_client=self.llm_client, - llm_config=self.agent_state.llm_config, - call_type=LLMCallType.agent_step, - agent_id=self.agent_state.id, - agent_tags=self.agent_state.tags, - run_id=run_id, - org_id=self.actor.organization_id, - user_id=self.actor.id, - ), + llm_adapter=llm_adapter, run_id=run_id, # use_assistant_message=use_assistant_message, include_return_message_types=include_return_message_types, @@ -292,7 +323,13 @@ class LettaAgentV3(LettaAgentV2): response_letta_messages = [m for m in response_letta_messages if m.message_type in include_return_message_types] # Set context_tokens to expose actual context window usage (vs accumulated prompt_tokens) self.usage.context_tokens = self.context_token_estimate - result = LettaResponse(messages=response_letta_messages, stop_reason=self.stop_reason, usage=self.usage) + result = LettaResponse( + messages=response_letta_messages, + stop_reason=self.stop_reason, + usage=self.usage, + logprobs=self.logprobs, + turns=self.turns if self.return_token_ids and self.turns else None, + ) if run_id: if self.job_update_metadata is None: self.job_update_metadata = {} @@ -355,6 +392,14 @@ class LettaAgentV3(LettaAgentV2): actor=self.actor, ) + # Check if we should use SGLang native adapter for multi-turn RL training + use_sglang_native = ( + self.agent_state.llm_config.return_token_ids + and self.agent_state.llm_config.handle + and self.agent_state.llm_config.handle.startswith("sglang/") + ) + self.return_token_ids = use_sglang_native + if stream_tokens: llm_adapter = SimpleLLMStreamAdapter( llm_client=self.llm_client, @@ -366,6 +411,20 @@ class LettaAgentV3(LettaAgentV2): org_id=self.actor.organization_id, user_id=self.actor.id, ) + elif use_sglang_native: + # Use SGLang native adapter for multi-turn RL training + llm_adapter = SGLangNativeAdapter( + llm_client=self.llm_client, + llm_config=self.agent_state.llm_config, + call_type=LLMCallType.agent_step, + agent_id=self.agent_state.id, + agent_tags=self.agent_state.tags, + run_id=run_id, + org_id=self.actor.organization_id, + user_id=self.actor.id, + ) + # Reset turns tracking for this step + self.turns = [] else: llm_adapter = SimpleLLMRequestAdapter( llm_client=self.llm_client, @@ -488,7 +547,13 @@ class LettaAgentV3(LettaAgentV2): if run_id: # Filter out LettaStopReason from messages (only valid in LettaStreamingResponse, not LettaResponse) filtered_messages = [m for m in response_letta_messages if not isinstance(m, LettaStopReason)] - result = LettaResponse(messages=filtered_messages, stop_reason=self.stop_reason, usage=self.usage) + result = LettaResponse( + messages=filtered_messages, + stop_reason=self.stop_reason, + usage=self.usage, + logprobs=self.logprobs, + turns=self.turns if self.return_token_ids and self.turns else None, + ) if self.job_update_metadata is None: self.job_update_metadata = {} self.job_update_metadata["result"] = result.model_dump(mode="json") @@ -970,6 +1035,19 @@ class LettaAgentV3(LettaAgentV2): self.context_token_estimate = llm_adapter.usage.total_tokens self.logger.info(f"Context token estimate after LLM request: {self.context_token_estimate}") + # Extract logprobs if present (for RL training) + if llm_adapter.logprobs is not None: + self.logprobs = llm_adapter.logprobs + + # Track turn data for multi-turn RL training (SGLang native mode) + if self.return_token_ids and hasattr(llm_adapter, "output_ids") and llm_adapter.output_ids: + self.turns.append(TurnTokenData( + role="assistant", + output_ids=llm_adapter.output_ids, + output_token_logprobs=llm_adapter.output_token_logprobs, + content=llm_adapter.chat_completions_response.choices[0].message.content if llm_adapter.chat_completions_response else None, + )) + # Handle the AI response with the extracted data (supports multiple tool calls) # Gather tool calls - check for multi-call API first, then fall back to single if hasattr(llm_adapter, "tool_calls") and llm_adapter.tool_calls: @@ -1015,6 +1093,34 @@ class LettaAgentV3(LettaAgentV2): # extend trackers with new messages self.response_messages.extend(new_messages) messages.extend(new_messages) + + # Track tool return turns for multi-turn RL training + if self.return_token_ids: + for msg in new_messages: + if msg.role == "tool": + # Get tool return content + tool_content = None + tool_name = None + if hasattr(msg, "tool_returns") and msg.tool_returns: + # Aggregate all tool returns into content (func_response is the actual content) + parts = [] + for tr in msg.tool_returns: + if hasattr(tr, 'func_response') and tr.func_response: + if isinstance(tr.func_response, str): + parts.append(tr.func_response) + else: + parts.append(str(tr.func_response)) + tool_content = "\n".join(parts) + elif hasattr(msg, "content") and msg.content: + tool_content = msg.content if isinstance(msg.content, str) else str(msg.content) + if hasattr(msg, "name"): + tool_name = msg.name + if tool_content: + self.turns.append(TurnTokenData( + role="tool", + content=tool_content, + tool_name=tool_name, + )) # step(...) has successfully completed! now we can persist messages and update the in-context messages + save metrics # persistence needs to happen before streaming to minimize chances of agent getting into an inconsistent state diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index 7257fc04..93ddc32b 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -511,6 +511,12 @@ class OpenAIClient(LLMClientBase): if llm_config.frequency_penalty is not None: data.frequency_penalty = llm_config.frequency_penalty + # Add logprobs configuration for RL training + if llm_config.return_logprobs: + data.logprobs = True + if llm_config.top_logprobs is not None: + data.top_logprobs = llm_config.top_logprobs + if tools and supports_parallel_tool_calling(model): data.parallel_tool_calls = False diff --git a/letta/llm_api/sglang_native_client.py b/letta/llm_api/sglang_native_client.py new file mode 100644 index 00000000..341be2c5 --- /dev/null +++ b/letta/llm_api/sglang_native_client.py @@ -0,0 +1,108 @@ +""" +SGLang Native Client for Letta. + +This client uses SGLang's native /generate endpoint instead of the OpenAI-compatible +/v1/chat/completions endpoint. The native endpoint returns token IDs and per-token +logprobs, which are essential for multi-turn RL training. + +The OpenAI-compatible endpoint only returns token strings, not IDs, making it +impossible to accurately reconstruct the token sequence for training. +""" + +from typing import Any, Dict, List, Optional + +import httpx + +from letta.log import get_logger + +logger = get_logger(__name__) + + +class SGLangNativeClient: + """Client for SGLang's native /generate endpoint. + + Unlike the OpenAI-compatible endpoint, this returns: + - output_ids: List of token IDs + - output_token_logprobs: List of [logprob, token_id, top_logprob] tuples + + This is essential for RL training where we need exact token IDs, not re-tokenized text. + """ + + def __init__(self, base_url: str, api_key: Optional[str] = None): + """ + Initialize the SGLang native client. + + Args: + base_url: Base URL for SGLang server (e.g., http://localhost:30000) + api_key: Optional API key for authentication + """ + # Remove /v1 suffix if present - native endpoint is at root + self.base_url = base_url.rstrip("/") + if self.base_url.endswith("/v1"): + self.base_url = self.base_url[:-3] + self.api_key = api_key + + async def generate( + self, + text: str, + sampling_params: Optional[Dict[str, Any]] = None, + return_logprob: bool = True, + ) -> Dict[str, Any]: + """ + Call SGLang's native /generate endpoint. + + Args: + text: The formatted prompt text (with chat template applied) + sampling_params: Sampling parameters (temperature, max_new_tokens, etc.) + return_logprob: Whether to return logprobs (default True for RL training) + + Returns: + Response dict with: + - text: Generated text + - output_ids: List of token IDs + - output_token_logprobs: List of [logprob, token_id, top_logprob] tuples + - meta_info: Metadata including finish_reason, prompt_tokens, etc. + + Example response: + { + "text": "Hello! How can I help?", + "output_ids": [9707, 0, 2585, 646, 358, 1492, 30], + "output_token_logprobs": [ + [-0.005, 9707, null], + [0.0, 0, null], + ... + ], + "meta_info": { + "finish_reason": {"type": "stop", "matched": 151645}, + "prompt_tokens": 42, + ... + } + } + """ + headers = {"Content-Type": "application/json"} + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + + payload = { + "text": text, + "sampling_params": sampling_params or {}, + "return_logprob": return_logprob, + } + + async with httpx.AsyncClient(timeout=300.0) as client: + response = await client.post( + f"{self.base_url}/generate", + json=payload, + headers=headers, + ) + response.raise_for_status() + return response.json() + + async def health_check(self) -> bool: + """Check if the SGLang server is healthy.""" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get(f"{self.base_url}/health") + return response.status_code == 200 + except Exception: + return False diff --git a/letta/schemas/letta_request.py b/letta/schemas/letta_request.py index e72aa333..fcff8c24 100644 --- a/letta/schemas/letta_request.py +++ b/letta/schemas/letta_request.py @@ -80,6 +80,25 @@ class LettaRequest(BaseModel): "If False (default), compaction messages are not included in the response.", ) + # Log probabilities for RL training + return_logprobs: bool = Field( + default=False, + description="If True, returns log probabilities of the output tokens in the response. " + "Useful for RL training. Only supported for OpenAI-compatible providers (including SGLang).", + ) + top_logprobs: Optional[int] = Field( + default=None, + description="Number of most likely tokens to return at each position (0-20). " + "Requires return_logprobs=True.", + ) + return_token_ids: bool = Field( + default=False, + description="If True, returns token IDs and logprobs for ALL LLM generations in the agent step, " + "not just the last one. Uses SGLang native /generate endpoint. " + "Returns 'turns' field with TurnTokenData for each assistant/tool turn. " + "Required for proper multi-turn RL training with loss masking.", + ) + @field_validator("messages", mode="before") @classmethod def add_default_type_to_messages(cls, v): diff --git a/letta/schemas/letta_response.py b/letta/schemas/letta_response.py index 68ac2dc3..a964ab99 100644 --- a/letta/schemas/letta_response.py +++ b/letta/schemas/letta_response.py @@ -2,12 +2,13 @@ import html import json import re from datetime import datetime -from typing import List, Union +from typing import Any, List, Literal, Optional, Union from pydantic import BaseModel, Field, RootModel from letta.helpers.json_helpers import json_dumps from letta.schemas.enums import JobStatus, MessageStreamStatus +from letta.schemas.openai.chat_completion_response import ChoiceLogprobs from letta.schemas.letta_message import ( ApprovalRequestMessage, ApprovalResponseMessage, @@ -30,6 +31,35 @@ from letta.schemas.usage import LettaUsageStatistics # TODO: consider moving into own file +class TurnTokenData(BaseModel): + """Token data for a single LLM generation turn in a multi-turn agent interaction. + + Used for RL training to track token IDs and logprobs across all LLM calls, + not just the final one. Tool results are included so the client can tokenize + them with loss_mask=0 (non-trainable). + """ + role: Literal["assistant", "tool"] = Field( + ..., + description="Role of this turn: 'assistant' for LLM generations (trainable), 'tool' for tool results (non-trainable)." + ) + output_ids: Optional[List[int]] = Field( + None, + description="Token IDs from SGLang native endpoint. Only present for assistant turns." + ) + output_token_logprobs: Optional[List[List[Any]]] = Field( + None, + description="Logprobs from SGLang: [[logprob, token_id, top_logprob_or_null], ...]. Only present for assistant turns." + ) + content: Optional[str] = Field( + None, + description="Text content. For tool turns, client tokenizes this with loss_mask=0." + ) + tool_name: Optional[str] = Field( + None, + description="Name of the tool called. Only present for tool turns." + ) + + class LettaResponse(BaseModel): """ Response object from an agent interaction, consisting of the new messages generated by the agent and usage statistics. @@ -57,6 +87,16 @@ class LettaResponse(BaseModel): ..., description="The usage statistics of the agent.", ) + logprobs: Optional[ChoiceLogprobs] = Field( + None, + description="Log probabilities of the output tokens from the last LLM call. Only present if return_logprobs was enabled.", + ) + turns: Optional[List[TurnTokenData]] = Field( + None, + description="Token data for all LLM generations in multi-turn agent interaction. " + "Includes token IDs and logprobs for each assistant turn, plus tool result content. " + "Only present if return_token_ids was enabled. Used for RL training with loss masking.", + ) def __str__(self): return json_dumps( diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py index e0953d40..5479898f 100644 --- a/letta/schemas/llm_config.py +++ b/letta/schemas/llm_config.py @@ -112,6 +112,19 @@ class LLMConfig(BaseModel): False, description="Enable strict mode for tool calling. When true, tool schemas include strict: true and additionalProperties: false, guaranteeing tool outputs match JSON schemas.", ) + return_logprobs: bool = Field( + False, + description="Whether to return log probabilities of the output tokens. Useful for RL training.", + ) + top_logprobs: Optional[int] = Field( + None, + description="Number of most likely tokens to return at each position (0-20). Requires return_logprobs=True.", + ) + return_token_ids: bool = Field( + False, + description="Whether to return token IDs for all LLM generations via SGLang native endpoint. " + "Required for multi-turn RL training with loss masking. Only works with SGLang provider.", + ) @model_validator(mode="before") @classmethod diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 0adb1620..09832ecd 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -1673,6 +1673,20 @@ async def send_message( result = None run_status = RunStatus.failed # Default to failed, updated on success try: + # Handle request-level logprobs override + if request.return_logprobs or request.return_token_ids: + agent = agent.model_copy( + update={ + "llm_config": agent.llm_config.model_copy( + update={ + "return_logprobs": request.return_logprobs, + "top_logprobs": request.top_logprobs, + "return_token_ids": request.return_token_ids, + } + ) + } + ) + agent_loop = AgentLoop.load(agent_state=agent, actor=actor) result = await agent_loop.step( request.messages, diff --git a/tests/managers/test_agent_manager.py b/tests/managers/test_agent_manager.py index 6e9921df..dcf2ddb3 100644 --- a/tests/managers/test_agent_manager.py +++ b/tests/managers/test_agent_manager.py @@ -1853,6 +1853,9 @@ async def test_agent_state_schema_unchanged(server: SyncServer): "tier", "parallel_tool_calls", "strict", + "return_logprobs", + "top_logprobs", + "return_token_ids", } actual_llm_config_fields = set(llm_config_fields.keys()) if actual_llm_config_fields != expected_llm_config_fields: