feat: add log probabilities from OpenAI-compatible servers and SGLang native endpoint (#9240)

* Add log probabilities support for RL training

This enables the Letta server to request and return log probabilities from
OpenAI-compatible providers (including SGLang) for use in RL training.

Changes:
- LLMConfig: Add return_logprobs and top_logprobs fields
- OpenAIClient: Set logprobs/top_logprobs on the ChatCompletionRequest when enabled (see the sketch after this list)
- LettaLLMAdapter: Add logprobs field and extract from response
- LettaResponse: Add logprobs field to return log probs to client
- LettaRequest: Add return_logprobs/top_logprobs for per-request override
- LettaAgentV3: Store and pass logprobs through to response
- agents.py: Handle request-level logprobs override
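
A minimal sketch of how these fields might map onto the standard OpenAI
chat completions parameters (the client setup and model handle below are
illustrative, not taken from this change):

  from openai import OpenAI

  client = OpenAI()  # any OpenAI-compatible server, e.g. via base_url

  completion = client.chat.completions.create(
      model="gpt-4o-mini",  # illustrative
      messages=[{"role": "user", "content": "hi"}],
      logprobs=True,        # <- LLMConfig.return_logprobs
      top_logprobs=5,       # <- LLMConfig.top_logprobs
  )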

Usage:
  response = client.agents.messages.create(
      agent_id=agent_id,
      messages=[...],
      return_logprobs=True,
      top_logprobs=5,
  )
  print(response.logprobs)  # Per-token log probabilities
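
Assuming the returned object mirrors OpenAI's ChoiceLogprobs shape (as the
import in the diff below suggests), per-token values can be read like:

  for tok in response.logprobs.content or []:
      print(tok.token, tok.logprob)  # chosen token and its log prob
      for alt in tok.top_logprobs:
          print("  alt:", alt.token, alt.logprob)  # top-k alternatives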

🤖 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* Add multi-turn token tracking for RL training via SGLang native endpoint

- Add TurnTokenData schema to track token IDs and logprobs per turn (sketched after this list)
- Add return_token_ids flag to LettaRequest and LLMConfig
- Create SGLangNativeClient for /generate endpoint (returns output_ids)
- Create SGLangNativeAdapter that uses native endpoint
- Modify LettaAgentV3 to accumulate turns across LLM calls
- Include turns in LettaResponse when return_token_ids=True
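
A rough sketch of the per-turn schema, with field types inferred from how
the agent populates it in the diff below (exact definitions may differ):

  from pydantic import BaseModel

  class TurnTokenData(BaseModel):
      role: str                                  # "assistant" or "tool"
      content: str | None = None                 # decoded text of the turn
      output_ids: list[int] | None = None        # token IDs from /generate
      output_token_logprobs: list[float | None] | None = None  # nulls allowed
      tool_name: str | None = None               # set for tool-return turns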

* Fix: Add SGLang native adapter to step() method, not just stream()

* Fix: Handle Pydantic Message objects in SGLang native adapter

* Fix: Remove api_key reference from LLMConfig (not present)

* Fix: Add missing 'created' field to ChatCompletionResponse

* Add full tool support to SGLang native adapter

- Format tools into the prompt in Qwen style
- Parse tool calls from <tool_call> tags in the response (see the parsing sketch after this list)
- Format tool results as <tool_response> in user messages
- Set finish_reason to 'tool_calls' when tools are called
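
For illustration, the kind of parsing this implies (a sketch, not the
adapter's actual parser):

  import json
  import re

  TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

  def parse_tool_calls(text: str) -> list[dict]:
      """Extract {"name": ..., "arguments": ...} payloads from <tool_call> tags."""
      calls = []
      for match in TOOL_CALL_RE.finditer(text):
          try:
              calls.append(json.loads(match.group(1)))
          except json.JSONDecodeError:
              continue  # skip malformed blocks rather than failing the turn
      return calls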

* Use tokenizer.apply_chat_template for proper tool formatting

- Add tokenizer caching in SGLang native adapter
- Use apply_chat_template when a tokenizer is available (sketched below)
- Fall back to manual formatting when it is not
- Convert Letta messages to OpenAI format for tokenizer
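
Roughly, the tokenizer path looks like this (the model name and variables
are illustrative; the adapter caches the tokenizer rather than reloading it):

  from transformers import AutoTokenizer

  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # illustrative
  prompt = tokenizer.apply_chat_template(
      messages,                  # OpenAI-format role/content dicts
      tools=tools,               # JSON-schema tool definitions
      tokenize=False,
      add_generation_prompt=True,
  )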

* Fix: Use func_response instead of tool_return for ToolReturn content

* Fix: Get output_token_logprobs from meta_info in SGLang response

* Fix: Allow None in output_token_logprobs (SGLang format includes null)
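
The last two fixes concern the shape of SGLang's /generate response: token
logprobs live under meta_info, as entries whose first element is the logprob
and may be null. A minimal extraction sketch under that assumption:

  def extract_token_logprobs(response: dict) -> list[float | None]:
      entries = response.get("meta_info", {}).get("output_token_logprobs", [])
      return [entry[0] for entry in entries]  # keep None values as-is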

* chore: remove unrelated files from logprobs branch

🤖 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* fix: add missing call_type param to adapter constructors in letta_agent_v3

The SGLang refactor dropped call_type=LLMCallType.agent_step when extracting
adapter creation into conditional blocks. This restores it in all three spots
(SGLang in step, SimpleLLM in step, SGLang in stream).

🤖 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* just stage-api && just publish-api

* fix: update expected LLMConfig fields in schema test for logprobs support

🤖 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* chore: remove rllm provider references

🤖 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* just stage-api && just publish-api

🤖 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-65-206.ec2.internal>
Co-authored-by: Letta <noreply@letta.com>
Commit 23c94ec6d3 by Kevin Lin, 2026-02-10 07:12:38 -08:00
Committed by: Caren Thomas (parent f9f1c55c93)
13 changed files with 1305 additions and 103 deletions

File: letta/agents/letta_agent_v3.py

@@ -6,6 +6,7 @@ from typing import Any, AsyncGenerator, Dict, Literal, Optional
 from opentelemetry.trace import Span
 
 from letta.adapters.letta_llm_adapter import LettaLLMAdapter
+from letta.adapters.sglang_native_adapter import SGLangNativeAdapter
 from letta.adapters.simple_llm_request_adapter import SimpleLLMRequestAdapter
 from letta.adapters.simple_llm_stream_adapter import SimpleLLMStreamAdapter
 from letta.agents.helpers import (
@@ -41,11 +42,11 @@ from letta.schemas.letta_message import (
 )
 from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
 from letta.schemas.letta_request import ClientToolSchema
-from letta.schemas.letta_response import LettaResponse
+from letta.schemas.letta_response import LettaResponse, TurnTokenData
 from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message, MessageCreate, ToolReturn
-from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall, ToolCallDenial, UsageStatistics
+from letta.schemas.openai.chat_completion_response import ChoiceLogprobs, FunctionCall, ToolCall, ToolCallDenial, UsageStatistics
 from letta.schemas.step import StepProgression
 from letta.schemas.step_metrics import StepMetrics
 from letta.schemas.tool_execution_result import ToolExecutionResult
@@ -121,6 +122,11 @@ class LettaAgentV3(LettaAgentV2):
         self.conversation_id: str | None = None
         # Client-side tools passed in the request (executed by client, not server)
         self.client_tools: list[ClientToolSchema] = []
+        # Log probabilities from the most recent LLM call (for RL training)
+        self.logprobs: ChoiceLogprobs | None = None
+        # Multi-turn token tracking for RL training (accumulated across all LLM calls)
+        self.turns: list[TurnTokenData] = []
+        self.return_token_ids: bool = False
 
     def _compute_tool_return_truncation_chars(self) -> int:
         """Compute a dynamic cap for tool returns in requests.
@@ -196,6 +202,41 @@ class LettaAgentV3(LettaAgentV2):
                 input_messages_to_persist = [input_messages_to_persist[0]]
             self.in_context_messages = curr_in_context_messages
 
+        # Check if we should use SGLang native adapter for multi-turn RL training
+        use_sglang_native = (
+            self.agent_state.llm_config.return_token_ids
+            and self.agent_state.llm_config.handle
+            and self.agent_state.llm_config.handle.startswith("sglang/")
+        )
+        self.return_token_ids = use_sglang_native
+
+        if use_sglang_native:
+            # Use SGLang native adapter for multi-turn RL training
+            llm_adapter = SGLangNativeAdapter(
+                llm_client=self.llm_client,
+                llm_config=self.agent_state.llm_config,
+                call_type=LLMCallType.agent_step,
+                agent_id=self.agent_state.id,
+                agent_tags=self.agent_state.tags,
+                run_id=run_id,
+                org_id=self.actor.organization_id,
+                user_id=self.actor.id,
+            )
+            # Reset turns tracking for this step
+            self.turns = []
+        else:
+            llm_adapter = SimpleLLMRequestAdapter(
+                llm_client=self.llm_client,
+                llm_config=self.agent_state.llm_config,
+                call_type=LLMCallType.agent_step,
+                agent_id=self.agent_state.id,
+                agent_tags=self.agent_state.tags,
+                run_id=run_id,
+                org_id=self.actor.organization_id,
+                user_id=self.actor.id,
+            )
+
         for i in range(max_steps):
             if i == 1 and follow_up_messages:
                 input_messages_to_persist = follow_up_messages
@@ -205,17 +246,7 @@ class LettaAgentV3(LettaAgentV2):
                 # we append input_messages_to_persist since they aren't checkpointed as in-context until the end of the step (may be rolled back)
                 messages=list(self.in_context_messages + input_messages_to_persist),
                 input_messages_to_persist=input_messages_to_persist,
-                # TODO need to support non-streaming adapter too
-                llm_adapter=SimpleLLMRequestAdapter(
-                    llm_client=self.llm_client,
-                    llm_config=self.agent_state.llm_config,
-                    call_type=LLMCallType.agent_step,
-                    agent_id=self.agent_state.id,
-                    agent_tags=self.agent_state.tags,
-                    run_id=run_id,
-                    org_id=self.actor.organization_id,
-                    user_id=self.actor.id,
-                ),
+                llm_adapter=llm_adapter,
                 run_id=run_id,
                 # use_assistant_message=use_assistant_message,
                 include_return_message_types=include_return_message_types,
@@ -292,7 +323,13 @@ class LettaAgentV3(LettaAgentV2):
             response_letta_messages = [m for m in response_letta_messages if m.message_type in include_return_message_types]
         # Set context_tokens to expose actual context window usage (vs accumulated prompt_tokens)
         self.usage.context_tokens = self.context_token_estimate
-        result = LettaResponse(messages=response_letta_messages, stop_reason=self.stop_reason, usage=self.usage)
+        result = LettaResponse(
+            messages=response_letta_messages,
+            stop_reason=self.stop_reason,
+            usage=self.usage,
+            logprobs=self.logprobs,
+            turns=self.turns if self.return_token_ids and self.turns else None,
+        )
         if run_id:
             if self.job_update_metadata is None:
                 self.job_update_metadata = {}
@@ -355,6 +392,14 @@ class LettaAgentV3(LettaAgentV2):
             actor=self.actor,
         )
 
+        # Check if we should use SGLang native adapter for multi-turn RL training
+        use_sglang_native = (
+            self.agent_state.llm_config.return_token_ids
+            and self.agent_state.llm_config.handle
+            and self.agent_state.llm_config.handle.startswith("sglang/")
+        )
+        self.return_token_ids = use_sglang_native
+
         if stream_tokens:
             llm_adapter = SimpleLLMStreamAdapter(
                 llm_client=self.llm_client,
@@ -366,6 +411,20 @@ class LettaAgentV3(LettaAgentV2):
                 org_id=self.actor.organization_id,
                 user_id=self.actor.id,
             )
+        elif use_sglang_native:
+            # Use SGLang native adapter for multi-turn RL training
+            llm_adapter = SGLangNativeAdapter(
+                llm_client=self.llm_client,
+                llm_config=self.agent_state.llm_config,
+                call_type=LLMCallType.agent_step,
+                agent_id=self.agent_state.id,
+                agent_tags=self.agent_state.tags,
+                run_id=run_id,
+                org_id=self.actor.organization_id,
+                user_id=self.actor.id,
+            )
+            # Reset turns tracking for this step
+            self.turns = []
         else:
             llm_adapter = SimpleLLMRequestAdapter(
                 llm_client=self.llm_client,
@@ -488,7 +547,13 @@ class LettaAgentV3(LettaAgentV2):
             if run_id:
                 # Filter out LettaStopReason from messages (only valid in LettaStreamingResponse, not LettaResponse)
                 filtered_messages = [m for m in response_letta_messages if not isinstance(m, LettaStopReason)]
-                result = LettaResponse(messages=filtered_messages, stop_reason=self.stop_reason, usage=self.usage)
+                result = LettaResponse(
+                    messages=filtered_messages,
+                    stop_reason=self.stop_reason,
+                    usage=self.usage,
+                    logprobs=self.logprobs,
+                    turns=self.turns if self.return_token_ids and self.turns else None,
+                )
                 if self.job_update_metadata is None:
                     self.job_update_metadata = {}
                 self.job_update_metadata["result"] = result.model_dump(mode="json")
@@ -970,6 +1035,19 @@ class LettaAgentV3(LettaAgentV2):
             self.context_token_estimate = llm_adapter.usage.total_tokens
             self.logger.info(f"Context token estimate after LLM request: {self.context_token_estimate}")
 
+        # Extract logprobs if present (for RL training)
+        if llm_adapter.logprobs is not None:
+            self.logprobs = llm_adapter.logprobs
+
+        # Track turn data for multi-turn RL training (SGLang native mode)
+        if self.return_token_ids and hasattr(llm_adapter, "output_ids") and llm_adapter.output_ids:
+            self.turns.append(TurnTokenData(
+                role="assistant",
+                output_ids=llm_adapter.output_ids,
+                output_token_logprobs=llm_adapter.output_token_logprobs,
+                content=llm_adapter.chat_completions_response.choices[0].message.content if llm_adapter.chat_completions_response else None,
+            ))
+
         # Handle the AI response with the extracted data (supports multiple tool calls)
         # Gather tool calls - check for multi-call API first, then fall back to single
         if hasattr(llm_adapter, "tool_calls") and llm_adapter.tool_calls:
@@ -1015,6 +1093,34 @@ class LettaAgentV3(LettaAgentV2):
         # extend trackers with new messages
         self.response_messages.extend(new_messages)
         messages.extend(new_messages)
+
+        # Track tool return turns for multi-turn RL training
+        if self.return_token_ids:
+            for msg in new_messages:
+                if msg.role == "tool":
+                    # Get tool return content
+                    tool_content = None
+                    tool_name = None
+                    if hasattr(msg, "tool_returns") and msg.tool_returns:
+                        # Aggregate all tool returns into content (func_response is the actual content)
+                        parts = []
+                        for tr in msg.tool_returns:
+                            if hasattr(tr, 'func_response') and tr.func_response:
+                                if isinstance(tr.func_response, str):
+                                    parts.append(tr.func_response)
+                                else:
+                                    parts.append(str(tr.func_response))
+                        tool_content = "\n".join(parts)
+                    elif hasattr(msg, "content") and msg.content:
+                        tool_content = msg.content if isinstance(msg.content, str) else str(msg.content)
+                    if hasattr(msg, "name"):
+                        tool_name = msg.name
+                    if tool_content:
+                        self.turns.append(TurnTokenData(
+                            role="tool",
+                            content=tool_content,
+                            tool_name=tool_name,
+                        ))
 
         # step(...) has successfully completed! now we can persist messages and update the in-context messages + save metrics
         # persistence needs to happen before streaming to minimize chances of agent getting into an inconsistent state