feat: add log probabilities from OpenAI-compatible servers and SGLang native endpoint (#9240)

* Add log probabilities support for RL training

This enables the Letta server to request and return log probabilities from
OpenAI-compatible providers (including SGLang) for use in RL training.

Changes:
- LLMConfig: Add return_logprobs and top_logprobs fields
- OpenAIClient: Set logprobs in ChatCompletionRequest when enabled
- LettaLLMAdapter: Add logprobs field and extract from response
- LettaResponse: Add logprobs field to return log probs to client
- LettaRequest: Add return_logprobs/top_logprobs for per-request override
- LettaAgentV3: Store and pass logprobs through to response
- agents.py: Handle request-level logprobs override

Usage:
  response = client.agents.messages.create(
      agent_id=agent_id,
      messages=[...],
      return_logprobs=True,
      top_logprobs=5,
  )
  print(response.logprobs)  # Per-token log probabilities

🤖 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* Add multi-turn token tracking for RL training via SGLang native endpoint

- Add TurnTokenData schema to track token IDs and logprobs per turn
- Add return_token_ids flag to LettaRequest and LLMConfig
- Create SGLangNativeClient for /generate endpoint (returns output_ids)
- Create SGLangNativeAdapter that uses native endpoint
- Modify LettaAgentV3 to accumulate turns across LLM calls
- Include turns in LettaResponse when return_token_ids=True
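A rough sketch of how a client might turn the accumulated turns into training inputs (the turn dicts mirror the TurnTokenData fields above; the token IDs and the `tokenize` callable are placeholders for illustration, not real Letta API surface):

```python
# Build flat token/loss-mask lists from per-turn data: assistant turns carry
# server-side token IDs (trainable, mask=1); tool turns carry only text that
# the client tokenizes itself with mask=0, as described above.
def build_loss_mask(turns, tokenize):
    token_ids: list[int] = []
    loss_mask: list[int] = []
    for turn in turns:
        if turn["role"] == "assistant":
            ids = turn.get("output_ids") or []
            token_ids.extend(ids)
            loss_mask.extend([1] * len(ids))  # trainable
        else:
            ids = tokenize(turn.get("content") or "")
            token_ids.extend(ids)
            loss_mask.extend([0] * len(ids))  # non-trainable tool result
    return token_ids, loss_mask

turns = [
    {"role": "assistant", "output_ids": [101, 102, 103]},
    {"role": "tool", "content": "ok"},
    {"role": "assistant", "output_ids": [104]},
]
# Placeholder tokenizer: one token id (0) per character.
ids, mask = build_loss_mask(turns, tokenize=lambda s: [0] * len(s))
# mask == [1, 1, 1, 0, 0, 1]
```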

* Fix: Add SGLang native adapter to step() method, not just stream()

* Fix: Handle Pydantic Message objects in SGLang native adapter

* Fix: Remove api_key reference from LLMConfig (not present)

* Fix: Add missing 'created' field to ChatCompletionResponse

* Add full tool support to SGLang native adapter

- Format tools into prompt in Qwen-style format
- Parse tool calls from <tool_call> tags in response
- Format tool results as <tool_response> in user messages
- Set finish_reason to 'tool_calls' when tools are called
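For illustration, the tag format above can be parsed back out with the same kind of regex the adapter uses (the tool name and arguments here are invented):

```python
import json
import re

# A response in the Qwen-style tagged format described above.
response_text = (
    "Let me check the weather.\n"
    "<tool_call>\n"
    '{"name": "get_weather", "arguments": {"city": "Berlin"}}\n'
    "</tool_call>"
)

# Extract each JSON object wrapped in <tool_call> tags.
calls = [
    json.loads(m)
    for m in re.findall(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", response_text, re.DOTALL)
]
# calls[0]["name"] == "get_weather"
```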

* Use tokenizer.apply_chat_template for proper tool formatting

- Add tokenizer caching in SGLang native adapter
- Use apply_chat_template when tokenizer available
- Fall back to manual formatting if not
- Convert Letta messages to OpenAI format for tokenizer
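The Letta-to-OpenAI conversion step can be sketched in isolation (plain dicts plus a stand-in message class; real Letta Message objects carry more fields, and the actual apply_chat_template call is omitted here):

```python
class Msg:
    """Stand-in for a Pydantic Letta Message (illustration only)."""
    def __init__(self, role, content):
        self.role = role
        self.content = content

def to_openai(messages):
    # Accept both attribute-style objects and plain dicts, flattening
    # list-valued content into one string, as the adapter does.
    out = []
    for m in messages:
        if hasattr(m, "role"):
            role, content = m.role, m.content
        else:
            role, content = m.get("role", "user"), m.get("content", "")
        if isinstance(content, list):
            content = " ".join(getattr(c, "text", str(c)) for c in content)
        out.append({"role": role, "content": content or ""})
    return out

msgs = to_openai([Msg("system", "You are helpful."), {"role": "user", "content": "hi"}])
# msgs[0] == {"role": "system", "content": "You are helpful."}
```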

* Fix: Use func_response instead of tool_return for ToolReturn content

* Fix: Get output_token_logprobs from meta_info in SGLang response

* Fix: Allow None in output_token_logprobs (SGLang format includes null)
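The meta_info shape these two fixes handle looks roughly like this (all values invented; the entry format [[logprob, token_id, top_logprob_or_null], ...] is as described above, where the third element may be null):

```python
# Illustrative /generate response: logprobs live under meta_info,
# and each entry's third slot may be None (null in JSON).
response = {
    "text": "Hello",
    "output_ids": [9906],
    "meta_info": {
        "output_token_logprobs": [[-0.12, 9906, None]],
        "finish_reason": {"type": "stop"},
    },
}

meta = response.get("meta_info", {})
entries = meta.get("output_token_logprobs") or []
# Unpack per-token logprobs, tolerating None in the third slot.
logprobs = [lp for lp, _token_id, _top in entries]
# logprobs == [-0.12]
```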

* chore: remove unrelated files from logprobs branch


* fix: add missing call_type param to adapter constructors in letta_agent_v3

The SGLang refactor dropped call_type=LLMCallType.agent_step when extracting
adapter creation into conditional blocks. This restores it at all three call
sites (SGLang in step, SimpleLLM in step, SGLang in stream).


* just stage-api && just publish-api

* fix: update expected LLMConfig fields in schema test for logprobs support


* chore: remove rllm provider references


* just stage-api && just publish-api


---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-65-206.ec2.internal>
Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
Kevin Lin
2026-02-10 07:12:38 -08:00
committed by Caren Thomas
parent f9f1c55c93
commit 23c94ec6d3
13 changed files with 1305 additions and 103 deletions


@@ -29192,43 +29192,6 @@
"title": "ChatCompletionSystemMessageParam",
"description": "Developer-provided instructions that the model should follow, regardless of\nmessages sent by the user. With o1 models and newer, use `developer` messages\nfor this purpose instead."
},
"ChatCompletionTokenLogprob": {
"properties": {
"token": {
"type": "string",
"title": "Token"
},
"bytes": {
"anyOf": [
{
"items": {
"type": "integer"
},
"type": "array"
},
{
"type": "null"
}
],
"title": "Bytes"
},
"logprob": {
"type": "number",
"title": "Logprob"
},
"top_logprobs": {
"items": {
"$ref": "#/components/schemas/TopLogprob"
},
"type": "array",
"title": "Top Logprobs"
}
},
"additionalProperties": true,
"type": "object",
"required": ["token", "logprob", "top_logprobs"],
"title": "ChatCompletionTokenLogprob"
},
"ChatCompletionToolMessageParam": {
"properties": {
"content": {
@@ -29453,7 +29416,7 @@
"logprobs": {
"anyOf": [
{
- "$ref": "#/components/schemas/ChoiceLogprobs"
+ "$ref": "#/components/schemas/openai__types__chat__chat_completion__ChoiceLogprobs"
},
{
"type": "null"
@@ -29469,42 +29432,6 @@
"required": ["finish_reason", "index", "message"],
"title": "Choice"
},
"ChoiceLogprobs": {
"properties": {
"content": {
"anyOf": [
{
"items": {
"$ref": "#/components/schemas/ChatCompletionTokenLogprob"
},
"type": "array"
},
{
"type": "null"
}
],
"title": "Content"
},
"refusal": {
"anyOf": [
{
"items": {
"$ref": "#/components/schemas/ChatCompletionTokenLogprob"
},
"type": "array"
},
{
"type": "null"
}
],
"title": "Refusal"
}
},
"additionalProperties": true,
"type": "object",
"title": "ChoiceLogprobs",
"description": "Log probability information for the choice."
},
"ClientToolSchema": {
"properties": {
"name": {
@@ -30525,6 +30452,30 @@
"description": "If True, compaction events emit structured `SummaryMessage` and `EventMessage` types. If False (default), compaction messages are not included in the response.",
"default": false
},
"return_logprobs": {
"type": "boolean",
"title": "Return Logprobs",
"description": "If True, returns log probabilities of the output tokens in the response. Useful for RL training. Only supported for OpenAI-compatible providers (including SGLang).",
"default": false
},
"top_logprobs": {
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"title": "Top Logprobs",
"description": "Number of most likely tokens to return at each position (0-20). Requires return_logprobs=True."
},
"return_token_ids": {
"type": "boolean",
"title": "Return Token Ids",
"description": "If True, returns token IDs and logprobs for ALL LLM generations in the agent step, not just the last one. Uses SGLang native /generate endpoint. Returns 'turns' field with TurnTokenData for each assistant/tool turn. Required for proper multi-turn RL training with loss masking.",
"default": false
},
"streaming": {
"type": "boolean",
"title": "Streaming",
@@ -36716,6 +36667,30 @@
"title": "Strict",
"description": "Enable strict mode for tool calling. When true, tool schemas include strict: true and additionalProperties: false, guaranteeing tool outputs match JSON schemas.",
"default": false
},
"return_logprobs": {
"type": "boolean",
"title": "Return Logprobs",
"description": "Whether to return log probabilities of the output tokens. Useful for RL training.",
"default": false
},
"top_logprobs": {
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"title": "Top Logprobs",
"description": "Number of most likely tokens to return at each position (0-20). Requires return_logprobs=True."
},
"return_token_ids": {
"type": "boolean",
"title": "Return Token Ids",
"description": "Whether to return token IDs for all LLM generations via SGLang native endpoint. Required for multi-turn RL training with loss masking. Only works with SGLang provider.",
"default": false
}
},
"type": "object",
@@ -36888,6 +36863,30 @@
"description": "If True, compaction events emit structured `SummaryMessage` and `EventMessage` types. If False (default), compaction messages are not included in the response.",
"default": false
},
"return_logprobs": {
"type": "boolean",
"title": "Return Logprobs",
"description": "If True, returns log probabilities of the output tokens in the response. Useful for RL training. Only supported for OpenAI-compatible providers (including SGLang).",
"default": false
},
"top_logprobs": {
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"title": "Top Logprobs",
"description": "Number of most likely tokens to return at each position (0-20). Requires return_logprobs=True."
},
"return_token_ids": {
"type": "boolean",
"title": "Return Token Ids",
"description": "If True, returns token IDs and logprobs for ALL LLM generations in the agent step, not just the last one. Uses SGLang native /generate endpoint. Returns 'turns' field with TurnTokenData for each assistant/tool turn. Required for proper multi-turn RL training with loss masking.",
"default": false
},
"callback_url": {
"anyOf": [
{
@@ -37083,6 +37082,30 @@
"description": "If True, compaction events emit structured `SummaryMessage` and `EventMessage` types. If False (default), compaction messages are not included in the response.",
"default": false
},
"return_logprobs": {
"type": "boolean",
"title": "Return Logprobs",
"description": "If True, returns log probabilities of the output tokens in the response. Useful for RL training. Only supported for OpenAI-compatible providers (including SGLang).",
"default": false
},
"top_logprobs": {
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"title": "Top Logprobs",
"description": "Number of most likely tokens to return at each position (0-20). Requires return_logprobs=True."
},
"return_token_ids": {
"type": "boolean",
"title": "Return Token Ids",
"description": "If True, returns token IDs and logprobs for ALL LLM generations in the agent step, not just the last one. Uses SGLang native /generate endpoint. Returns 'turns' field with TurnTokenData for each assistant/tool turn. Required for proper multi-turn RL training with loss masking.",
"default": false
},
"agent_id": {
"type": "string",
"maxLength": 42,
@@ -37457,6 +37480,30 @@
"title": "Include Compaction Messages",
"description": "If True, compaction events emit structured `SummaryMessage` and `EventMessage` types. If False (default), compaction messages are not included in the response.",
"default": false
},
"return_logprobs": {
"type": "boolean",
"title": "Return Logprobs",
"description": "If True, returns log probabilities of the output tokens in the response. Useful for RL training. Only supported for OpenAI-compatible providers (including SGLang).",
"default": false
},
"top_logprobs": {
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"title": "Top Logprobs",
"description": "Number of most likely tokens to return at each position (0-20). Requires return_logprobs=True."
},
"return_token_ids": {
"type": "boolean",
"title": "Return Token Ids",
"description": "If True, returns token IDs and logprobs for ALL LLM generations in the agent step, not just the last one. Uses SGLang native /generate endpoint. Returns 'turns' field with TurnTokenData for each assistant/tool turn. Required for proper multi-turn RL training with loss masking.",
"default": false
}
},
"type": "object",
@@ -37517,6 +37564,32 @@
"usage": {
"$ref": "#/components/schemas/LettaUsageStatistics",
"description": "The usage statistics of the agent."
},
"logprobs": {
"anyOf": [
{
"$ref": "#/components/schemas/letta__schemas__openai__chat_completion_response__ChoiceLogprobs"
},
{
"type": "null"
}
],
"description": "Log probabilities of the output tokens from the last LLM call. Only present if return_logprobs was enabled."
},
"turns": {
"anyOf": [
{
"items": {
"$ref": "#/components/schemas/TurnTokenData"
},
"type": "array"
},
{
"type": "null"
}
],
"title": "Turns",
"description": "Token data for all LLM generations in multi-turn agent interaction. Includes token IDs and logprobs for each assistant turn, plus tool result content. Only present if return_token_ids was enabled. Used for RL training with loss masking."
}
},
"type": "object",
@@ -37708,6 +37781,30 @@
"description": "If True, compaction events emit structured `SummaryMessage` and `EventMessage` types. If False (default), compaction messages are not included in the response.",
"default": false
},
"return_logprobs": {
"type": "boolean",
"title": "Return Logprobs",
"description": "If True, returns log probabilities of the output tokens in the response. Useful for RL training. Only supported for OpenAI-compatible providers (including SGLang).",
"default": false
},
"top_logprobs": {
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"title": "Top Logprobs",
"description": "Number of most likely tokens to return at each position (0-20). Requires return_logprobs=True."
},
"return_token_ids": {
"type": "boolean",
"title": "Return Token Ids",
"description": "If True, returns token IDs and logprobs for ALL LLM generations in the agent step, not just the last one. Uses SGLang native /generate endpoint. Returns 'turns' field with TurnTokenData for each assistant/tool turn. Required for proper multi-turn RL training with loss masking.",
"default": false
},
"streaming": {
"type": "boolean",
"title": "Streaming",
@@ -39265,6 +39362,30 @@
"description": "Enable strict mode for tool calling. When true, tool schemas include strict: true and additionalProperties: false, guaranteeing tool outputs match JSON schemas.",
"default": false
},
"return_logprobs": {
"type": "boolean",
"title": "Return Logprobs",
"description": "Whether to return log probabilities of the output tokens. Useful for RL training.",
"default": false
},
"top_logprobs": {
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"title": "Top Logprobs",
"description": "Number of most likely tokens to return at each position (0-20). Requires return_logprobs=True."
},
"return_token_ids": {
"type": "boolean",
"title": "Return Token Ids",
"description": "Whether to return token IDs for all LLM generations via SGLang native endpoint. Required for multi-turn RL training with loss masking. Only works with SGLang provider.",
"default": false
},
"max_context_window": {
"type": "integer",
"title": "Max Context Window",
@@ -45313,13 +45434,15 @@
"type": "object",
"title": "ToolUpdate"
},
- "TopLogprob": {
+ "TurnTokenData": {
"properties": {
- "token": {
+ "role": {
"type": "string",
- "title": "Token"
+ "enum": ["assistant", "tool"],
+ "title": "Role",
+ "description": "Role of this turn: 'assistant' for LLM generations (trainable), 'tool' for tool results (non-trainable)."
},
- "bytes": {
+ "output_ids": {
"anyOf": [
{
"items": {
@@ -45331,17 +45454,54 @@
"type": "null"
}
],
- "title": "Bytes"
+ "title": "Output Ids",
+ "description": "Token IDs from SGLang native endpoint. Only present for assistant turns."
},
- "logprob": {
- "type": "number",
- "title": "Logprob"
+ "output_token_logprobs": {
"anyOf": [
{
"items": {
"items": {},
"type": "array"
},
"type": "array"
},
{
"type": "null"
}
],
"title": "Output Token Logprobs",
"description": "Logprobs from SGLang: [[logprob, token_id, top_logprob_or_null], ...]. Only present for assistant turns."
},
"content": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"title": "Content",
"description": "Text content. For tool turns, client tokenizes this with loss_mask=0."
},
"tool_name": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"title": "Tool Name",
"description": "Name of the tool called. Only present for tool turns."
}
},
"additionalProperties": true,
"type": "object",
- "required": ["token", "logprob"],
- "title": "TopLogprob"
+ "required": ["role"],
+ "title": "TurnTokenData",
+ "description": "Token data for a single LLM generation turn in a multi-turn agent interaction.\n\nUsed for RL training to track token IDs and logprobs across all LLM calls,\nnot just the final one. Tool results are included so the client can tokenize\nthem with loss_mask=0 (non-trainable)."
},
"UpdateAgent": {
"properties": {
@@ -48653,6 +48813,105 @@
"required": ["status"],
"title": "ToolReturn"
},
"letta__schemas__openai__chat_completion_response__ChatCompletionTokenLogprob": {
"properties": {
"token": {
"type": "string",
"title": "Token"
},
"bytes": {
"anyOf": [
{
"items": {
"type": "integer"
},
"type": "array"
},
{
"type": "null"
}
],
"title": "Bytes"
},
"logprob": {
"type": "number",
"title": "Logprob"
},
"top_logprobs": {
"items": {
"$ref": "#/components/schemas/letta__schemas__openai__chat_completion_response__TopLogprob"
},
"type": "array",
"title": "Top Logprobs"
}
},
"type": "object",
"required": ["token", "logprob", "top_logprobs"],
"title": "ChatCompletionTokenLogprob"
},
"letta__schemas__openai__chat_completion_response__ChoiceLogprobs": {
"properties": {
"content": {
"anyOf": [
{
"items": {
"$ref": "#/components/schemas/letta__schemas__openai__chat_completion_response__ChatCompletionTokenLogprob"
},
"type": "array"
},
{
"type": "null"
}
],
"title": "Content"
},
"refusal": {
"anyOf": [
{
"items": {
"$ref": "#/components/schemas/letta__schemas__openai__chat_completion_response__ChatCompletionTokenLogprob"
},
"type": "array"
},
{
"type": "null"
}
],
"title": "Refusal"
}
},
"type": "object",
"title": "ChoiceLogprobs"
},
"letta__schemas__openai__chat_completion_response__TopLogprob": {
"properties": {
"token": {
"type": "string",
"title": "Token"
},
"bytes": {
"anyOf": [
{
"items": {
"type": "integer"
},
"type": "array"
},
{
"type": "null"
}
],
"title": "Bytes"
},
"logprob": {
"type": "number",
"title": "Logprob"
}
},
"type": "object",
"required": ["token", "logprob"],
"title": "TopLogprob"
},
"letta__serialize_schemas__pydantic_agent_schema__AgentSchema": {
"properties": {
"agent_type": {
@@ -48999,6 +49258,42 @@
"type": "object",
"title": "ToolExecuteRequest"
},
"openai__types__chat__chat_completion__ChoiceLogprobs": {
"properties": {
"content": {
"anyOf": [
{
"items": {
"$ref": "#/components/schemas/openai__types__chat__chat_completion_token_logprob__ChatCompletionTokenLogprob"
},
"type": "array"
},
{
"type": "null"
}
],
"title": "Content"
},
"refusal": {
"anyOf": [
{
"items": {
"$ref": "#/components/schemas/openai__types__chat__chat_completion_token_logprob__ChatCompletionTokenLogprob"
},
"type": "array"
},
{
"type": "null"
}
],
"title": "Refusal"
}
},
"additionalProperties": true,
"type": "object",
"title": "ChoiceLogprobs",
"description": "Log probability information for the choice."
},
"openai__types__chat__chat_completion_message_function_tool_call__Function": {
"properties": {
"arguments": {
@@ -49032,6 +49327,73 @@
"title": "Function",
"description": "The function that the model called."
},
"openai__types__chat__chat_completion_token_logprob__ChatCompletionTokenLogprob": {
"properties": {
"token": {
"type": "string",
"title": "Token"
},
"bytes": {
"anyOf": [
{
"items": {
"type": "integer"
},
"type": "array"
},
{
"type": "null"
}
],
"title": "Bytes"
},
"logprob": {
"type": "number",
"title": "Logprob"
},
"top_logprobs": {
"items": {
"$ref": "#/components/schemas/openai__types__chat__chat_completion_token_logprob__TopLogprob"
},
"type": "array",
"title": "Top Logprobs"
}
},
"additionalProperties": true,
"type": "object",
"required": ["token", "logprob", "top_logprobs"],
"title": "ChatCompletionTokenLogprob"
},
"openai__types__chat__chat_completion_token_logprob__TopLogprob": {
"properties": {
"token": {
"type": "string",
"title": "Token"
},
"bytes": {
"anyOf": [
{
"items": {
"type": "integer"
},
"type": "array"
},
{
"type": "null"
}
],
"title": "Bytes"
},
"logprob": {
"type": "number",
"title": "Logprob"
}
},
"additionalProperties": true,
"type": "object",
"required": ["token", "logprob"],
"title": "TopLogprob"
},
"LettaMessageUnion": {
"oneOf": [
{


@@ -1,12 +1,12 @@
from abc import ABC, abstractmethod
from typing import AsyncGenerator
from typing import AsyncGenerator, Optional
from letta.llm_api.llm_client_base import LLMClientBase
from letta.schemas.enums import LLMCallType
from letta.schemas.letta_message import LettaMessage
from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, ToolCall
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, ChoiceLogprobs, ToolCall
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.services.telemetry_manager import TelemetryManager
@@ -48,6 +48,10 @@ class LettaLLMAdapter(ABC):
self.content: list[TextContent | ReasoningContent | RedactedReasoningContent] | None = None
self.tool_call: ToolCall | None = None
self.tool_calls: list[ToolCall] = []
self.logprobs: ChoiceLogprobs | None = None
# SGLang native endpoint data (for multi-turn RL training)
self.output_ids: list[int] | None = None
self.output_token_logprobs: list[list[float]] | None = None
self.usage: LettaUsageStatistics = LettaUsageStatistics()
self.telemetry_manager: TelemetryManager = TelemetryManager()
self.llm_request_finish_timestamp_ns: int | None = None


@@ -83,6 +83,9 @@ class LettaLLMRequestAdapter(LettaLLMAdapter):
else:
self.tool_call = None
# Extract logprobs if present
self.logprobs = self.chat_completions_response.choices[0].logprobs
# Extract usage statistics
self.usage.step_count = 1
self.usage.completion_tokens = self.chat_completions_response.usage.completion_tokens


@@ -0,0 +1,521 @@
"""
SGLang Native Adapter for multi-turn RL training.
This adapter uses SGLang's native /generate endpoint instead of the OpenAI-compatible
endpoint to get token IDs and per-token logprobs, which are essential for proper
multi-turn RL training with loss masking.
Uses HuggingFace tokenizer's apply_chat_template() for proper tool formatting.
"""
import json
import re
import time
import uuid
from typing import Any, AsyncGenerator, Optional
from letta.adapters.simple_llm_request_adapter import SimpleLLMRequestAdapter
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
from letta.llm_api.sglang_native_client import SGLangNativeClient
from letta.log import get_logger
from letta.schemas.letta_message import LettaMessage
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, TextContent
from letta.schemas.openai.chat_completion_response import (
ChatCompletionResponse,
Choice,
ChoiceLogprobs,
ChatCompletionTokenLogprob,
FunctionCall,
Message as ChoiceMessage,
ToolCall,
UsageStatistics,
)
from letta.schemas.usage import normalize_cache_tokens, normalize_reasoning_tokens
logger = get_logger(__name__)
# Global tokenizer cache
_tokenizer_cache: dict[str, Any] = {}
class SGLangNativeAdapter(SimpleLLMRequestAdapter):
"""
Adapter that uses SGLang's native /generate endpoint for multi-turn RL training.
Key differences from SimpleLLMRequestAdapter:
- Uses /generate instead of /v1/chat/completions
- Returns output_ids (token IDs) in addition to text
- Returns output_token_logprobs with [logprob, token_id] pairs
- Formats tools into prompt and parses tool calls from response
These are essential for building accurate loss masks in multi-turn training.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._sglang_client: Optional[SGLangNativeClient] = None
self._tokenizer: Any = None
def _get_tokenizer(self) -> Any:
"""Get or create tokenizer for the model."""
global _tokenizer_cache
# Get model name from llm_config
model_name = self.llm_config.model
if not model_name:
logger.warning("No model name in llm_config, cannot load tokenizer")
return None
# Check cache
if model_name in _tokenizer_cache:
return _tokenizer_cache[model_name]
try:
from transformers import AutoTokenizer
logger.info(f"Loading tokenizer for model: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
_tokenizer_cache[model_name] = tokenizer
return tokenizer
except ImportError:
logger.warning("transformers not installed, falling back to manual formatting")
return None
except Exception as e:
logger.warning(f"Failed to load tokenizer: {e}, falling back to manual formatting")
return None
def _get_sglang_client(self) -> SGLangNativeClient:
"""Get or create SGLang native client."""
if self._sglang_client is None:
# Get base URL from llm_config, removing /v1 suffix if present
base_url = self.llm_config.model_endpoint or ""
# SGLang local instances typically don't need API key
self._sglang_client = SGLangNativeClient(
base_url=base_url,
api_key=None,
)
return self._sglang_client
def _format_tools_for_prompt(self, tools: list) -> str:
"""
Format tools in Qwen3 chat template format for the system prompt.
This matches the exact format produced by Qwen3's tokenizer.apply_chat_template()
with tools parameter.
"""
if not tools:
return ""
# Format each tool as JSON (matching Qwen3 template exactly)
tool_jsons = []
for tool in tools:
# Handle both dict and object formats
if isinstance(tool, dict):
# Already in OpenAI format
tool_jsons.append(json.dumps(tool))
else:
# Convert object to dict
tool_dict = {
"type": "function",
"function": {
"name": getattr(getattr(tool, "function", tool), "name", ""),
"description": getattr(getattr(tool, "function", tool), "description", ""),
"parameters": getattr(getattr(tool, "function", tool), "parameters", {}),
}
}
tool_jsons.append(json.dumps(tool_dict))
# Use exact Qwen3 format
tools_section = (
"\n\n# Tools\n\n"
"You may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n"
"<tools>\n"
+ "\n".join(tool_jsons) + "\n"
"</tools>\n\n"
"For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n"
"<tool_call>\n"
'{"name": <function-name>, "arguments": <args-json-object>}\n'
"</tool_call>"
)
return tools_section
def _convert_messages_to_openai_format(self, messages: list) -> list[dict]:
"""Convert Letta Message objects to OpenAI-style message dicts."""
openai_messages = []
for msg in messages:
# Handle both dict and Pydantic Message objects
if hasattr(msg, 'role'):
role = msg.role
content = msg.content if hasattr(msg, 'content') else ""
# Handle content that might be a list of content parts
if isinstance(content, list):
content = " ".join([c.text if hasattr(c, 'text') else str(c) for c in content])
elif content is None:
content = ""
tool_calls = getattr(msg, 'tool_calls', None)
tool_call_id = getattr(msg, 'tool_call_id', None)
name = getattr(msg, 'name', None)
else:
role = msg.get("role", "user")
content = msg.get("content", "")
tool_calls = msg.get("tool_calls", None)
tool_call_id = msg.get("tool_call_id", None)
name = msg.get("name", None)
openai_msg = {"role": role, "content": content}
if tool_calls:
# Convert tool calls to OpenAI format
openai_tool_calls = []
for tc in tool_calls:
if hasattr(tc, 'function'):
tc_dict = {
"id": getattr(tc, 'id', f"call_{uuid.uuid4().hex[:8]}"),
"type": "function",
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments if isinstance(tc.function.arguments, str) else json.dumps(tc.function.arguments)
}
}
else:
tc_dict = {
"id": tc.get("id", f"call_{uuid.uuid4().hex[:8]}"),
"type": "function",
"function": tc.get("function", {})
}
openai_tool_calls.append(tc_dict)
openai_msg["tool_calls"] = openai_tool_calls
if tool_call_id:
openai_msg["tool_call_id"] = tool_call_id
if name and role == "tool":
openai_msg["name"] = name
openai_messages.append(openai_msg)
return openai_messages
def _convert_tools_to_openai_format(self, tools: list) -> list[dict]:
"""Convert tools to OpenAI format for tokenizer."""
openai_tools = []
for tool in tools:
if isinstance(tool, dict):
# Already a dict, ensure it's in the right format
if "function" in tool:
openai_tools.append(tool)
else:
# Might be the function directly
openai_tools.append({"type": "function", "function": tool})
else:
# Convert object to dict
func = getattr(tool, "function", tool)
tool_dict = {
"type": "function",
"function": {
"name": getattr(func, "name", ""),
"description": getattr(func, "description", ""),
"parameters": getattr(func, "parameters", {}),
}
}
openai_tools.append(tool_dict)
return openai_tools
def _format_messages_to_text(self, messages: list, tools: list) -> str:
"""
Format messages to text using tokenizer's apply_chat_template if available.
Falls back to manual formatting if tokenizer is not available.
"""
tokenizer = self._get_tokenizer()
if tokenizer is not None:
# Use tokenizer's apply_chat_template for proper formatting
openai_messages = self._convert_messages_to_openai_format(messages)
openai_tools = self._convert_tools_to_openai_format(tools) if tools else None
try:
formatted = tokenizer.apply_chat_template(
openai_messages,
tokenize=False,
add_generation_prompt=True,
tools=openai_tools,
)
logger.debug(f"Formatted prompt using tokenizer ({len(formatted)} chars)")
return formatted
except Exception as e:
logger.warning(f"apply_chat_template failed: {e}, falling back to manual formatting")
# Fallback to manual formatting
return self._format_messages_to_text_manual(messages, tools)
def _format_messages_to_text_manual(self, messages: list, tools: list) -> str:
"""Manual fallback formatting for when tokenizer is not available."""
formatted_parts = []
tools_section = self._format_tools_for_prompt(tools)
for msg in messages:
# Handle both dict and Pydantic Message objects
if hasattr(msg, 'role'):
role = msg.role
content = msg.content if hasattr(msg, 'content') else ""
if isinstance(content, list):
content = " ".join([c.text if hasattr(c, 'text') else str(c) for c in content])
elif content is None:
content = ""
tool_calls = getattr(msg, 'tool_calls', None)
else:
role = msg.get("role", "user")
content = msg.get("content", "")
tool_calls = msg.get("tool_calls", None)
if role == "system":
system_content = content + tools_section if tools_section else content
formatted_parts.append(f"<|im_start|>system\n{system_content}<|im_end|>")
tools_section = ""
elif role == "user":
formatted_parts.append(f"<|im_start|>user\n{content}<|im_end|>")
elif role == "assistant":
if tool_calls:
tc_parts = []
for tc in tool_calls:
if hasattr(tc, 'function'):
tc_name = tc.function.name
tc_args = tc.function.arguments
else:
tc_name = tc.get("function", {}).get("name", "")
tc_args = tc.get("function", {}).get("arguments", "{}")
if isinstance(tc_args, str):
try:
tc_args = json.loads(tc_args)
except json.JSONDecodeError:
pass
tc_parts.append(
f"<tool_call>\n"
f'{{"name": "{tc_name}", "arguments": {json.dumps(tc_args)}}}\n'
f"</tool_call>"
)
assistant_content = content + "\n" + "\n".join(tc_parts) if content else "\n".join(tc_parts)
formatted_parts.append(f"<|im_start|>assistant\n{assistant_content}<|im_end|>")
elif content:
formatted_parts.append(f"<|im_start|>assistant\n{content}<|im_end|>")
elif role == "tool":
formatted_parts.append(
f"<|im_start|>user\n"
f"<tool_response>\n{content}\n</tool_response><|im_end|>"
)
formatted_parts.append("<|im_start|>assistant\n")
return "\n".join(formatted_parts)
def _parse_tool_calls(self, text: str) -> list[ToolCall]:
"""
Parse tool calls from response text.
Looks for patterns like:
<tool_call>
{"name": "tool_name", "arguments": {...}}
</tool_call>
"""
tool_calls = []
# Find all tool_call blocks
pattern = r'<tool_call>\s*(\{.*?\})\s*</tool_call>'
matches = re.findall(pattern, text, re.DOTALL)
for match in matches:
try:
tc_data = json.loads(match)
name = tc_data.get("name", "")
arguments = tc_data.get("arguments", {})
if isinstance(arguments, dict):
arguments = json.dumps(arguments)
tool_call = ToolCall(
id=f"call_{uuid.uuid4().hex[:8]}",
type="function",
function=FunctionCall(
name=name,
arguments=arguments,
),
)
tool_calls.append(tool_call)
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse tool call JSON: {e}")
continue
return tool_calls
def _extract_content_without_tool_calls(self, text: str) -> str:
"""Extract content from response, removing tool_call blocks."""
# Remove tool_call blocks
cleaned = re.sub(r'<tool_call>.*?</tool_call>', '', text, flags=re.DOTALL)
# Clean up whitespace
cleaned = cleaned.strip()
return cleaned
async def invoke_llm(
self,
request_data: dict,
messages: list,
tools: list,
use_assistant_message: bool,
requires_approval_tools: list[str] = [],
step_id: str | None = None,
actor: str | None = None,
) -> AsyncGenerator[LettaMessage | None, None]:
"""
Execute LLM request using SGLang native endpoint.
This method:
1. Formats messages and tools to text using chat template
2. Calls SGLang native /generate endpoint
3. Extracts output_ids and output_token_logprobs
4. Parses tool calls from response
5. Converts response to standard format
"""
self.request_data = request_data
# Get sampling params from request_data
sampling_params = {
"temperature": request_data.get("temperature", 0.7),
"max_new_tokens": request_data.get("max_tokens", 4096),
"top_p": request_data.get("top_p", 0.9),
}
# Format messages to text (includes tools in prompt)
text_input = self._format_messages_to_text(messages, tools)
# Call SGLang native endpoint
client = self._get_sglang_client()
try:
response = await client.generate(
text=text_input,
sampling_params=sampling_params,
return_logprob=True,
)
except Exception as e:
logger.error(f"SGLang native endpoint error: {e}")
raise
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
# Store native response data
self.response_data = response
# Extract SGLang native data
self.output_ids = response.get("output_ids")
# output_token_logprobs is inside meta_info
meta_info = response.get("meta_info", {})
self.output_token_logprobs = meta_info.get("output_token_logprobs")
# Extract text response
text_response = response.get("text", "")
# Remove trailing end token if present
text_response = text_response.removesuffix("<|im_end|>")
# Parse tool calls from response
parsed_tool_calls = self._parse_tool_calls(text_response)
# Extract content (text without tool_call blocks)
content_text = self._extract_content_without_tool_calls(text_response)
# Determine finish reason (meta_info was already extracted above)
finish_reason_info = meta_info.get("finish_reason", {})
if isinstance(finish_reason_info, dict):
finish_reason = finish_reason_info.get("type", "stop")
elif isinstance(finish_reason_info, str):
finish_reason = finish_reason_info
else:
finish_reason = "stop"
# If we have tool calls, set finish_reason to tool_calls
if parsed_tool_calls:
finish_reason = "tool_calls"
# Convert to standard ChatCompletionResponse format for compatibility
# Build logprobs in OpenAI format from SGLang format
logprobs_content = None
if self.output_token_logprobs:
logprobs_content = []
for lp_data in self.output_token_logprobs:
# SGLang format: [logprob, token_id, top_logprob]
logprob = lp_data[0] if len(lp_data) > 0 else 0.0
token_id = lp_data[1] if len(lp_data) > 1 else 0
logprobs_content.append(
ChatCompletionTokenLogprob(
token=str(token_id),
logprob=logprob,
bytes=None,
top_logprobs=[],
)
)
choice_logprobs = ChoiceLogprobs(content=logprobs_content) if logprobs_content else None
# Build chat completion response
prompt_tokens = meta_info.get("prompt_tokens", 0)
completion_tokens = len(self.output_ids) if self.output_ids else 0
self.chat_completions_response = ChatCompletionResponse(
id=meta_info.get("id", "sglang-native"),
created=int(time.time()),
choices=[
Choice(
finish_reason=finish_reason,
index=0,
message=ChoiceMessage(
role="assistant",
content=content_text if content_text else None,
tool_calls=parsed_tool_calls if parsed_tool_calls else None,
),
logprobs=choice_logprobs,
)
],
usage=UsageStatistics(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
# Extract content
if content_text:
self.content = [TextContent(text=content_text)]
else:
self.content = None
# No reasoning content from native endpoint
self.reasoning_content = None
# Set tool calls
self.tool_calls = parsed_tool_calls
self.tool_call = parsed_tool_calls[0] if parsed_tool_calls else None
# Set logprobs
self.logprobs = choice_logprobs
# Extract usage statistics
self.usage.step_count = 1
self.usage.completion_tokens = completion_tokens
self.usage.prompt_tokens = prompt_tokens
self.usage.total_tokens = prompt_tokens + completion_tokens
self.log_provider_trace(step_id=step_id, actor=actor)
logger.info(
f"SGLang native response: {len(self.output_ids or [])} tokens, "
f"{len(self.output_token_logprobs or [])} logprobs, "
f"{len(parsed_tool_calls)} tool calls"
)
yield None
return
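For reference, the `<tool_call>` extraction above reduces to a non-greedy regex plus `json.loads`. A minimal standalone sketch (function and variable names here are illustrative, not the adapter's API):

```python
import json
import re

TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

def parse_tool_calls(text: str) -> list[dict]:
    """Extract {"name", "arguments"} dicts from <tool_call> blocks."""
    calls = []
    for raw in TOOL_CALL_RE.findall(text):
        try:
            data = json.loads(raw)
        except json.JSONDecodeError:
            continue  # skip malformed blocks, mirroring the adapter's warning path
        calls.append({"name": data.get("name", ""), "arguments": data.get("arguments", {})})
    return calls

sample = 'Sure.\n<tool_call>\n{"name": "get_weather", "arguments": {"city": "SF"}}\n</tool_call>'
calls = parse_tool_calls(sample)
```

Note the non-greedy `\{.*?\}` with `re.DOTALL`: backtracking against the trailing `</tool_call>` still captures nested braces in the arguments object.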


@@ -99,6 +99,9 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
self.tool_calls = list(tool_calls)
self.tool_call = self.tool_calls[0] if self.tool_calls else None
# Extract logprobs if present
self.logprobs = self.chat_completions_response.choices[0].logprobs
# Extract usage statistics
self.usage.step_count = 1
self.usage.completion_tokens = self.chat_completions_response.usage.completion_tokens


@@ -6,6 +6,7 @@ from typing import Any, AsyncGenerator, Dict, Literal, Optional
from opentelemetry.trace import Span
from letta.adapters.letta_llm_adapter import LettaLLMAdapter
from letta.adapters.sglang_native_adapter import SGLangNativeAdapter
from letta.adapters.simple_llm_request_adapter import SimpleLLMRequestAdapter
from letta.adapters.simple_llm_stream_adapter import SimpleLLMStreamAdapter
from letta.agents.helpers import (
@@ -41,11 +42,11 @@ from letta.schemas.letta_message import (
)
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_request import ClientToolSchema
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_response import LettaResponse, TurnTokenData
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message, MessageCreate, ToolReturn
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall, ToolCallDenial, UsageStatistics
from letta.schemas.openai.chat_completion_response import ChoiceLogprobs, FunctionCall, ToolCall, ToolCallDenial, UsageStatistics
from letta.schemas.step import StepProgression
from letta.schemas.step_metrics import StepMetrics
from letta.schemas.tool_execution_result import ToolExecutionResult
@@ -121,6 +122,11 @@ class LettaAgentV3(LettaAgentV2):
self.conversation_id: str | None = None
# Client-side tools passed in the request (executed by client, not server)
self.client_tools: list[ClientToolSchema] = []
# Log probabilities from the most recent LLM call (for RL training)
self.logprobs: ChoiceLogprobs | None = None
# Multi-turn token tracking for RL training (accumulated across all LLM calls)
self.turns: list[TurnTokenData] = []
self.return_token_ids: bool = False
def _compute_tool_return_truncation_chars(self) -> int:
"""Compute a dynamic cap for tool returns in requests.
@@ -196,6 +202,41 @@ class LettaAgentV3(LettaAgentV2):
input_messages_to_persist = [input_messages_to_persist[0]]
self.in_context_messages = curr_in_context_messages
# Check if we should use SGLang native adapter for multi-turn RL training
use_sglang_native = (
self.agent_state.llm_config.return_token_ids
and self.agent_state.llm_config.handle
and self.agent_state.llm_config.handle.startswith("sglang/")
)
self.return_token_ids = use_sglang_native
if use_sglang_native:
# Use SGLang native adapter for multi-turn RL training
llm_adapter = SGLangNativeAdapter(
llm_client=self.llm_client,
llm_config=self.agent_state.llm_config,
call_type=LLMCallType.agent_step,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
)
# Reset turns tracking for this step
self.turns = []
else:
llm_adapter = SimpleLLMRequestAdapter(
llm_client=self.llm_client,
llm_config=self.agent_state.llm_config,
call_type=LLMCallType.agent_step,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
)
for i in range(max_steps):
if i == 1 and follow_up_messages:
input_messages_to_persist = follow_up_messages
@@ -205,17 +246,7 @@ class LettaAgentV3(LettaAgentV2):
# we append input_messages_to_persist since they aren't checkpointed as in-context until the end of the step (may be rolled back)
messages=list(self.in_context_messages + input_messages_to_persist),
input_messages_to_persist=input_messages_to_persist,
# TODO need to support non-streaming adapter too
llm_adapter=SimpleLLMRequestAdapter(
llm_client=self.llm_client,
llm_config=self.agent_state.llm_config,
call_type=LLMCallType.agent_step,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
),
llm_adapter=llm_adapter,
run_id=run_id,
# use_assistant_message=use_assistant_message,
include_return_message_types=include_return_message_types,
@@ -292,7 +323,13 @@ class LettaAgentV3(LettaAgentV2):
response_letta_messages = [m for m in response_letta_messages if m.message_type in include_return_message_types]
# Set context_tokens to expose actual context window usage (vs accumulated prompt_tokens)
self.usage.context_tokens = self.context_token_estimate
result = LettaResponse(messages=response_letta_messages, stop_reason=self.stop_reason, usage=self.usage)
result = LettaResponse(
messages=response_letta_messages,
stop_reason=self.stop_reason,
usage=self.usage,
logprobs=self.logprobs,
turns=self.turns if self.return_token_ids and self.turns else None,
)
if run_id:
if self.job_update_metadata is None:
self.job_update_metadata = {}
@@ -355,6 +392,14 @@ class LettaAgentV3(LettaAgentV2):
actor=self.actor,
)
# Check if we should use SGLang native adapter for multi-turn RL training
use_sglang_native = (
self.agent_state.llm_config.return_token_ids
and self.agent_state.llm_config.handle
and self.agent_state.llm_config.handle.startswith("sglang/")
)
self.return_token_ids = use_sglang_native
if stream_tokens:
llm_adapter = SimpleLLMStreamAdapter(
llm_client=self.llm_client,
@@ -366,6 +411,20 @@ class LettaAgentV3(LettaAgentV2):
org_id=self.actor.organization_id,
user_id=self.actor.id,
)
elif use_sglang_native:
# Use SGLang native adapter for multi-turn RL training
llm_adapter = SGLangNativeAdapter(
llm_client=self.llm_client,
llm_config=self.agent_state.llm_config,
call_type=LLMCallType.agent_step,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
)
# Reset turns tracking for this step
self.turns = []
else:
llm_adapter = SimpleLLMRequestAdapter(
llm_client=self.llm_client,
@@ -488,7 +547,13 @@ class LettaAgentV3(LettaAgentV2):
if run_id:
# Filter out LettaStopReason from messages (only valid in LettaStreamingResponse, not LettaResponse)
filtered_messages = [m for m in response_letta_messages if not isinstance(m, LettaStopReason)]
result = LettaResponse(messages=filtered_messages, stop_reason=self.stop_reason, usage=self.usage)
result = LettaResponse(
messages=filtered_messages,
stop_reason=self.stop_reason,
usage=self.usage,
logprobs=self.logprobs,
turns=self.turns if self.return_token_ids and self.turns else None,
)
if self.job_update_metadata is None:
self.job_update_metadata = {}
self.job_update_metadata["result"] = result.model_dump(mode="json")
@@ -970,6 +1035,19 @@ class LettaAgentV3(LettaAgentV2):
self.context_token_estimate = llm_adapter.usage.total_tokens
self.logger.info(f"Context token estimate after LLM request: {self.context_token_estimate}")
# Extract logprobs if present (for RL training)
if llm_adapter.logprobs is not None:
self.logprobs = llm_adapter.logprobs
# Track turn data for multi-turn RL training (SGLang native mode)
if self.return_token_ids and hasattr(llm_adapter, "output_ids") and llm_adapter.output_ids:
self.turns.append(TurnTokenData(
role="assistant",
output_ids=llm_adapter.output_ids,
output_token_logprobs=llm_adapter.output_token_logprobs,
content=llm_adapter.chat_completions_response.choices[0].message.content if llm_adapter.chat_completions_response else None,
))
# Handle the AI response with the extracted data (supports multiple tool calls)
# Gather tool calls - check for multi-call API first, then fall back to single
if hasattr(llm_adapter, "tool_calls") and llm_adapter.tool_calls:
@@ -1015,6 +1093,34 @@ class LettaAgentV3(LettaAgentV2):
# extend trackers with new messages
self.response_messages.extend(new_messages)
messages.extend(new_messages)
# Track tool return turns for multi-turn RL training
if self.return_token_ids:
for msg in new_messages:
if msg.role == "tool":
# Get tool return content
tool_content = None
tool_name = None
if hasattr(msg, "tool_returns") and msg.tool_returns:
# Aggregate all tool returns into content (func_response is the actual content)
parts = []
for tr in msg.tool_returns:
if hasattr(tr, 'func_response') and tr.func_response:
parts.append(str(tr.func_response))
tool_content = "\n".join(parts)
elif hasattr(msg, "content") and msg.content:
tool_content = msg.content if isinstance(msg.content, str) else str(msg.content)
if hasattr(msg, "name"):
tool_name = msg.name
if tool_content:
self.turns.append(TurnTokenData(
role="tool",
content=tool_content,
tool_name=tool_name,
))
# step(...) has successfully completed! now we can persist messages and update the in-context messages + save metrics
# persistence needs to happen before streaming to minimize chances of agent getting into an inconsistent state


@@ -511,6 +511,12 @@ class OpenAIClient(LLMClientBase):
if llm_config.frequency_penalty is not None:
data.frequency_penalty = llm_config.frequency_penalty
# Add logprobs configuration for RL training
if llm_config.return_logprobs:
data.logprobs = True
if llm_config.top_logprobs is not None:
data.top_logprobs = llm_config.top_logprobs
if tools and supports_parallel_tool_calling(model):
data.parallel_tool_calls = False


@@ -0,0 +1,108 @@
"""
SGLang Native Client for Letta.
This client uses SGLang's native /generate endpoint instead of the OpenAI-compatible
/v1/chat/completions endpoint. The native endpoint returns token IDs and per-token
logprobs, which are essential for multi-turn RL training.
The OpenAI-compatible endpoint only returns token strings, not IDs, making it
impossible to accurately reconstruct the token sequence for training.
"""
from typing import Any, Dict, List, Optional
import httpx
from letta.log import get_logger
logger = get_logger(__name__)
class SGLangNativeClient:
"""Client for SGLang's native /generate endpoint.
Unlike the OpenAI-compatible endpoint, this returns:
- output_ids: List of token IDs
- output_token_logprobs (inside meta_info): List of [logprob, token_id, top_logprob] tuples
This is essential for RL training where we need exact token IDs, not re-tokenized text.
"""
def __init__(self, base_url: str, api_key: Optional[str] = None):
"""
Initialize the SGLang native client.
Args:
base_url: Base URL for SGLang server (e.g., http://localhost:30000)
api_key: Optional API key for authentication
"""
# Remove /v1 suffix if present - native endpoint is at root
self.base_url = base_url.rstrip("/").removesuffix("/v1")
self.api_key = api_key
async def generate(
self,
text: str,
sampling_params: Optional[Dict[str, Any]] = None,
return_logprob: bool = True,
) -> Dict[str, Any]:
"""
Call SGLang's native /generate endpoint.
Args:
text: The formatted prompt text (with chat template applied)
sampling_params: Sampling parameters (temperature, max_new_tokens, etc.)
return_logprob: Whether to return logprobs (default True for RL training)
Returns:
Response dict with:
- text: Generated text
- output_ids: List of token IDs
- meta_info: Metadata including finish_reason, prompt_tokens, and
output_token_logprobs ([logprob, token_id, top_logprob] tuples)
Example response:
{
"text": "Hello! How can I help?",
"output_ids": [9707, 0, 2585, 646, 358, 1492, 30],
"meta_info": {
"finish_reason": {"type": "stop", "matched": 151645},
"prompt_tokens": 42,
"output_token_logprobs": [
[-0.005, 9707, null],
[0.0, 0, null],
...
],
...
}
}
"""
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
payload = {
"text": text,
"sampling_params": sampling_params or {},
"return_logprob": return_logprob,
}
async with httpx.AsyncClient(timeout=300.0) as client:
response = await client.post(
f"{self.base_url}/generate",
json=payload,
headers=headers,
)
response.raise_for_status()
return response.json()
async def health_check(self) -> bool:
"""Check if the SGLang server is healthy."""
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(f"{self.base_url}/health")
return response.status_code == 200
except Exception:
return False
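Client-side extraction of a native `/generate` response is straightforward. A hedged sketch, using a response dict shaped the way the adapter reads it (`output_token_logprobs` under `meta_info`; the sample values below are illustrative, not live server output):

```python
def extract_generation(response: dict) -> tuple[str, list[int], list[float]]:
    """Pull text, token IDs, and per-token logprobs out of a native /generate response."""
    text = response.get("text", "").removesuffix("<|im_end|>")
    output_ids = response.get("output_ids") or []
    # Each output_token_logprobs entry is [logprob, token_id, top_logprob_or_null]
    triples = response.get("meta_info", {}).get("output_token_logprobs") or []
    logprobs = [t[0] for t in triples]
    return text, output_ids, logprobs

response = {
    "text": "Hello!",
    "output_ids": [9707, 0],
    "meta_info": {
        "finish_reason": {"type": "stop", "matched": 151645},
        "prompt_tokens": 42,
        "output_token_logprobs": [[-0.005, 9707, None], [0.0, 0, None]],
    },
}
text, ids, lps = extract_generation(response)
```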


@@ -80,6 +80,25 @@ class LettaRequest(BaseModel):
"If False (default), compaction messages are not included in the response.",
)
# Log probabilities for RL training
return_logprobs: bool = Field(
default=False,
description="If True, returns log probabilities of the output tokens in the response. "
"Useful for RL training. Only supported for OpenAI-compatible providers (including SGLang).",
)
top_logprobs: Optional[int] = Field(
default=None,
description="Number of most likely tokens to return at each position (0-20). "
"Requires return_logprobs=True.",
)
return_token_ids: bool = Field(
default=False,
description="If True, returns token IDs and logprobs for ALL LLM generations in the agent step, "
"not just the last one. Uses SGLang native /generate endpoint. "
"Returns 'turns' field with TurnTokenData for each assistant/tool turn. "
"Required for proper multi-turn RL training with loss masking.",
)
@field_validator("messages", mode="before")
@classmethod
def add_default_type_to_messages(cls, v):


@@ -2,12 +2,13 @@ import html
import json
import re
from datetime import datetime
from typing import List, Union
from typing import Any, List, Literal, Optional, Union
from pydantic import BaseModel, Field, RootModel
from letta.helpers.json_helpers import json_dumps
from letta.schemas.enums import JobStatus, MessageStreamStatus
from letta.schemas.openai.chat_completion_response import ChoiceLogprobs
from letta.schemas.letta_message import (
ApprovalRequestMessage,
ApprovalResponseMessage,
@@ -30,6 +31,35 @@ from letta.schemas.usage import LettaUsageStatistics
# TODO: consider moving into own file
class TurnTokenData(BaseModel):
"""Token data for a single LLM generation turn in a multi-turn agent interaction.
Used for RL training to track token IDs and logprobs across all LLM calls,
not just the final one. Tool results are included so the client can tokenize
them with loss_mask=0 (non-trainable).
"""
role: Literal["assistant", "tool"] = Field(
...,
description="Role of this turn: 'assistant' for LLM generations (trainable), 'tool' for tool results (non-trainable)."
)
output_ids: Optional[List[int]] = Field(
None,
description="Token IDs from SGLang native endpoint. Only present for assistant turns."
)
output_token_logprobs: Optional[List[List[Any]]] = Field(
None,
description="Logprobs from SGLang: [[logprob, token_id, top_logprob_or_null], ...]. Only present for assistant turns."
)
content: Optional[str] = Field(
None,
description="Text content. For tool turns, client tokenizes this with loss_mask=0."
)
tool_name: Optional[str] = Field(
None,
description="Name of the tool called. Only present for tool turns."
)
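A training client might consume `turns` by flattening assistant turns (real token IDs, loss_mask=1) and tool turns (tokenized client-side, loss_mask=0) into one sequence. A sketch under the assumption of a caller-supplied `tokenize` function (hypothetical; a toy character-level tokenizer stands in for a real one):

```python
def build_training_sequence(turns: list[dict], tokenize) -> tuple[list[int], list[int]]:
    """Flatten TurnTokenData-style dicts into (token_ids, loss_mask)."""
    token_ids: list[int] = []
    loss_mask: list[int] = []
    for turn in turns:
        if turn["role"] == "assistant":
            ids = turn.get("output_ids") or []
            token_ids.extend(ids)
            loss_mask.extend([1] * len(ids))  # trainable
        else:  # tool result: tokenized client-side, non-trainable
            ids = tokenize(turn.get("content") or "")
            token_ids.extend(ids)
            loss_mask.extend([0] * len(ids))
    return token_ids, loss_mask

# Toy tokenizer: one "token" per character (stand-in for a real tokenizer)
toy_tokenize = lambda s: [ord(c) for c in s]
turns = [
    {"role": "assistant", "output_ids": [9707, 0]},
    {"role": "tool", "content": "ok"},
    {"role": "assistant", "output_ids": [2585]},
]
ids, mask = build_training_sequence(turns, toy_tokenize)
```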
class LettaResponse(BaseModel):
"""
Response object from an agent interaction, consisting of the new messages generated by the agent and usage statistics.
@@ -57,6 +87,16 @@ class LettaResponse(BaseModel):
...,
description="The usage statistics of the agent.",
)
logprobs: Optional[ChoiceLogprobs] = Field(
None,
description="Log probabilities of the output tokens from the last LLM call. Only present if return_logprobs was enabled.",
)
turns: Optional[List[TurnTokenData]] = Field(
None,
description="Token data for all LLM generations in multi-turn agent interaction. "
"Includes token IDs and logprobs for each assistant turn, plus tool result content. "
"Only present if return_token_ids was enabled. Used for RL training with loss masking.",
)
def __str__(self):
return json_dumps(


@@ -112,6 +112,19 @@ class LLMConfig(BaseModel):
False,
description="Enable strict mode for tool calling. When true, tool schemas include strict: true and additionalProperties: false, guaranteeing tool outputs match JSON schemas.",
)
return_logprobs: bool = Field(
False,
description="Whether to return log probabilities of the output tokens. Useful for RL training.",
)
top_logprobs: Optional[int] = Field(
None,
description="Number of most likely tokens to return at each position (0-20). Requires return_logprobs=True.",
)
return_token_ids: bool = Field(
False,
description="Whether to return token IDs for all LLM generations via SGLang native endpoint. "
"Required for multi-turn RL training with loss masking. Only works with SGLang provider.",
)
@model_validator(mode="before")
@classmethod


@@ -1673,6 +1673,20 @@ async def send_message(
result = None
run_status = RunStatus.failed # Default to failed, updated on success
try:
# Handle request-level logprobs override
if request.return_logprobs or request.return_token_ids:
agent = agent.model_copy(
update={
"llm_config": agent.llm_config.model_copy(
update={
"return_logprobs": request.return_logprobs,
"top_logprobs": request.top_logprobs,
"return_token_ids": request.return_token_ids,
}
)
}
)
agent_loop = AgentLoop.load(agent_state=agent, actor=actor)
result = await agent_loop.step(
request.messages,


@@ -1853,6 +1853,9 @@ async def test_agent_state_schema_unchanged(server: SyncServer):
"tier",
"parallel_tool_calls",
"strict",
"return_logprobs",
"top_logprobs",
"return_token_ids",
}
actual_llm_config_fields = set(llm_config_fields.keys())
if actual_llm_config_fields != expected_llm_config_fields: