feat: add log probabilities from OpenAI-compatible servers and SGLang native endpoint (#9240)
* Add log probabilities support for RL training
This enables the Letta server to request and return log probabilities from
OpenAI-compatible providers (including SGLang) for use in RL training.
Changes:
- LLMConfig: Add return_logprobs and top_logprobs fields
- OpenAIClient: Set logprobs in ChatCompletionRequest when enabled
- LettaLLMAdapter: Add logprobs field and extract from response
- LettaResponse: Add logprobs field to return log probs to client
- LettaRequest: Add return_logprobs/top_logprobs for per-request override
- LettaAgentV3: Store and pass logprobs through to response
- agents.py: Handle request-level logprobs override
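To make the OpenAIClient change concrete, here is a hedged sketch of how the logprobs flags might be threaded into an OpenAI-compatible chat completion request. The `logprobs`/`top_logprobs` field names mirror the OpenAI API, but `build_request_kwargs` and its surroundings are illustrative assumptions, not Letta's actual request builder:

```python
def build_request_kwargs(llm_config, messages):
    """Assemble chat-completion kwargs, adding logprobs flags when enabled.

    Hypothetical sketch: only the logprobs/top_logprobs fields come from the
    OpenAI API; the builder itself is an assumption for illustration.
    """
    kwargs = {"model": llm_config.model, "messages": messages}
    if getattr(llm_config, "return_logprobs", False):
        kwargs["logprobs"] = True
        top = getattr(llm_config, "top_logprobs", None)
        if top is not None:
            # The OpenAI API caps top_logprobs at 20; clamp defensively
            kwargs["top_logprobs"] = min(top, 20)
    return kwargs
```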
Usage:
    response = client.agents.messages.create(
        agent_id=agent_id,
        messages=[...],
        return_logprobs=True,
        top_logprobs=5,
    )
    print(response.logprobs)  # Per-token log probabilities
🤖 Generated with [Letta Code](https://letta.com)
Co-Authored-By: Letta <noreply@letta.com>
* Add multi-turn token tracking for RL training via SGLang native endpoint
- Add TurnTokenData schema to track token IDs and logprobs per turn
- Add return_token_ids flag to LettaRequest and LLMConfig
- Create SGLangNativeClient for /generate endpoint (returns output_ids)
- Create SGLangNativeAdapter that uses native endpoint
- Modify LettaAgentV3 to accumulate turns across LLM calls
- Include turns in LettaResponse when return_token_ids=True
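The TurnTokenData shape described by these bullets can be sketched as follows. The real Letta schema is a Pydantic model; this stdlib-dataclass version is an approximation, and any details beyond the fields named above are assumptions:

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class TurnTokenData:
    """One turn's worth of token data for multi-turn RL training (sketch)."""
    role: str                                   # "assistant" or "tool"
    content: Optional[str] = None               # turn text, if available
    tool_name: Optional[str] = None             # set for tool-return turns
    output_ids: Optional[List[int]] = None      # token IDs from /generate
    # SGLang can emit null logprobs, so list entries are allowed to be None
    output_token_logprobs: Optional[List[Optional[float]]] = None
```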
* Fix: Add SGLang native adapter to step() method, not just stream()
* Fix: Handle Pydantic Message objects in SGLang native adapter
* Fix: Remove api_key reference from LLMConfig (not present)
* Fix: Add missing 'created' field to ChatCompletionResponse
* Add full tool support to SGLang native adapter
- Format tools into prompt in Qwen-style format
- Parse tool calls from <tool_call> tags in response
- Format tool results as <tool_response> in user messages
- Set finish_reason to 'tool_calls' when tools are called
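The tool-call parsing step above can be sketched roughly like this: scan the completion for `<tool_call>` tags and decode the JSON payload inside each. This is a simplified illustration of the approach, not Letta's actual parser:

```python
import json
import re

def parse_tool_calls(completion_text):
    """Extract Qwen-style <tool_call>{...}</tool_call> blocks (sketch)."""
    calls = []
    for block in re.findall(r"<tool_call>\s*(\{.*?\})\s*</tool_call>",
                            completion_text, re.DOTALL):
        try:
            calls.append(json.loads(block))
        except json.JSONDecodeError:
            continue  # skip malformed blocks rather than failing the turn
    return calls
```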
* Use tokenizer.apply_chat_template for proper tool formatting
- Add tokenizer caching in SGLang native adapter
- Use apply_chat_template when tokenizer available
- Fall back to manual formatting if not
- Convert Letta messages to OpenAI format for tokenizer
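The tokenizer-with-fallback flow above might look like this in outline. `apply_chat_template(tokenize=False, add_generation_prompt=True)` is the real Hugging Face tokenizer API; the ChatML-style manual fallback layout is an assumption for illustration:

```python
def render_prompt(messages, tokenizer=None):
    """Render OpenAI-format messages to a prompt string (sketch).

    Uses tokenizer.apply_chat_template when a tokenizer is available;
    otherwise falls back to a simple ChatML-style manual format.
    """
    if tokenizer is not None:
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    # Manual fallback: ChatML-like framing, ending with an open assistant turn
    parts = [f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>" for m in messages]
    parts.append("<|im_start|>assistant")
    return "\n".join(parts)
```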
* Fix: Use func_response instead of tool_return for ToolReturn content
* Fix: Get output_token_logprobs from meta_info in SGLang response
* Fix: Allow None in output_token_logprobs (SGLang format includes null)
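These two fixes together imply an extraction step along these lines. The assumed response shape, `meta_info["output_token_logprobs"]` as a list of `(logprob, token_id, ...)` entries with possible nulls, is inferred from the fix descriptions above rather than taken from SGLang's documentation:

```python
def extract_token_logprobs(response_json):
    """Read token IDs and per-token logprobs from an SGLang /generate
    response (sketch; response shape is an assumption)."""
    entries = (response_json.get("meta_info") or {}).get("output_token_logprobs") or []
    token_ids = [e[1] for e in entries]
    # Preserve None entries instead of rejecting them (SGLang emits nulls)
    logprobs = [e[0] for e in entries]
    return token_ids, logprobs
```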
* chore: remove unrelated files from logprobs branch
🤖 Generated with [Letta Code](https://letta.com)
Co-Authored-By: Letta <noreply@letta.com>
* fix: add missing call_type param to adapter constructors in letta_agent_v3
The SGLang refactor dropped call_type=LLMCallType.agent_step when extracting
adapter creation into conditional blocks. This restores it at all three call
sites (SGLang in step, SimpleLLM in step, SGLang in stream).
🤖 Generated with [Letta Code](https://letta.com)
Co-Authored-By: Letta <noreply@letta.com>
* just stage-api && just publish-api
* fix: update expected LLMConfig fields in schema test for logprobs support
🤖 Generated with [Letta Code](https://letta.com)
Co-Authored-By: Letta <noreply@letta.com>
* chore: remove rllm provider references
🤖 Generated with [Letta Code](https://letta.com)
Co-Authored-By: Letta <noreply@letta.com>
* just stage-api && just publish-api
🤖 Generated with [Letta Code](https://letta.com)
Co-Authored-By: Letta <noreply@letta.com>
---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-65-206.ec2.internal>
Co-authored-by: Letta <noreply@letta.com>
@@ -6,6 +6,7 @@ from typing import Any, AsyncGenerator, Dict, Literal, Optional
 from opentelemetry.trace import Span

 from letta.adapters.letta_llm_adapter import LettaLLMAdapter
+from letta.adapters.sglang_native_adapter import SGLangNativeAdapter
 from letta.adapters.simple_llm_request_adapter import SimpleLLMRequestAdapter
 from letta.adapters.simple_llm_stream_adapter import SimpleLLMStreamAdapter
 from letta.agents.helpers import (
@@ -41,11 +42,11 @@ from letta.schemas.letta_message import (
 )
 from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
 from letta.schemas.letta_request import ClientToolSchema
-from letta.schemas.letta_response import LettaResponse
+from letta.schemas.letta_response import LettaResponse, TurnTokenData
 from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message, MessageCreate, ToolReturn
-from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall, ToolCallDenial, UsageStatistics
+from letta.schemas.openai.chat_completion_response import ChoiceLogprobs, FunctionCall, ToolCall, ToolCallDenial, UsageStatistics
 from letta.schemas.step import StepProgression
 from letta.schemas.step_metrics import StepMetrics
 from letta.schemas.tool_execution_result import ToolExecutionResult
@@ -121,6 +122,11 @@ class LettaAgentV3(LettaAgentV2):
         self.conversation_id: str | None = None
         # Client-side tools passed in the request (executed by client, not server)
         self.client_tools: list[ClientToolSchema] = []
+        # Log probabilities from the most recent LLM call (for RL training)
+        self.logprobs: ChoiceLogprobs | None = None
+        # Multi-turn token tracking for RL training (accumulated across all LLM calls)
+        self.turns: list[TurnTokenData] = []
+        self.return_token_ids: bool = False

     def _compute_tool_return_truncation_chars(self) -> int:
         """Compute a dynamic cap for tool returns in requests.
@@ -196,6 +202,41 @@ class LettaAgentV3(LettaAgentV2):
             input_messages_to_persist = [input_messages_to_persist[0]]

         self.in_context_messages = curr_in_context_messages

+        # Check if we should use SGLang native adapter for multi-turn RL training
+        use_sglang_native = (
+            self.agent_state.llm_config.return_token_ids
+            and self.agent_state.llm_config.handle
+            and self.agent_state.llm_config.handle.startswith("sglang/")
+        )
+        self.return_token_ids = use_sglang_native
+
+        if use_sglang_native:
+            # Use SGLang native adapter for multi-turn RL training
+            llm_adapter = SGLangNativeAdapter(
+                llm_client=self.llm_client,
+                llm_config=self.agent_state.llm_config,
+                call_type=LLMCallType.agent_step,
+                agent_id=self.agent_state.id,
+                agent_tags=self.agent_state.tags,
+                run_id=run_id,
+                org_id=self.actor.organization_id,
+                user_id=self.actor.id,
+            )
+            # Reset turns tracking for this step
+            self.turns = []
+        else:
+            llm_adapter = SimpleLLMRequestAdapter(
+                llm_client=self.llm_client,
+                llm_config=self.agent_state.llm_config,
+                call_type=LLMCallType.agent_step,
+                agent_id=self.agent_state.id,
+                agent_tags=self.agent_state.tags,
+                run_id=run_id,
+                org_id=self.actor.organization_id,
+                user_id=self.actor.id,
+            )
+
         for i in range(max_steps):
             if i == 1 and follow_up_messages:
                 input_messages_to_persist = follow_up_messages
@@ -205,17 +246,7 @@ class LettaAgentV3(LettaAgentV2):
                 # we append input_messages_to_persist since they aren't checkpointed as in-context until the end of the step (may be rolled back)
                 messages=list(self.in_context_messages + input_messages_to_persist),
                 input_messages_to_persist=input_messages_to_persist,
-                # TODO need to support non-streaming adapter too
-                llm_adapter=SimpleLLMRequestAdapter(
-                    llm_client=self.llm_client,
-                    llm_config=self.agent_state.llm_config,
-                    call_type=LLMCallType.agent_step,
-                    agent_id=self.agent_state.id,
-                    agent_tags=self.agent_state.tags,
-                    run_id=run_id,
-                    org_id=self.actor.organization_id,
-                    user_id=self.actor.id,
-                ),
+                llm_adapter=llm_adapter,
                 run_id=run_id,
                 # use_assistant_message=use_assistant_message,
                 include_return_message_types=include_return_message_types,
@@ -292,7 +323,13 @@ class LettaAgentV3(LettaAgentV2):
             response_letta_messages = [m for m in response_letta_messages if m.message_type in include_return_message_types]
         # Set context_tokens to expose actual context window usage (vs accumulated prompt_tokens)
         self.usage.context_tokens = self.context_token_estimate
-        result = LettaResponse(messages=response_letta_messages, stop_reason=self.stop_reason, usage=self.usage)
+        result = LettaResponse(
+            messages=response_letta_messages,
+            stop_reason=self.stop_reason,
+            usage=self.usage,
+            logprobs=self.logprobs,
+            turns=self.turns if self.return_token_ids and self.turns else None,
+        )
         if run_id:
             if self.job_update_metadata is None:
                 self.job_update_metadata = {}
@@ -355,6 +392,14 @@ class LettaAgentV3(LettaAgentV2):
             actor=self.actor,
         )

+        # Check if we should use SGLang native adapter for multi-turn RL training
+        use_sglang_native = (
+            self.agent_state.llm_config.return_token_ids
+            and self.agent_state.llm_config.handle
+            and self.agent_state.llm_config.handle.startswith("sglang/")
+        )
+        self.return_token_ids = use_sglang_native
+
         if stream_tokens:
             llm_adapter = SimpleLLMStreamAdapter(
                 llm_client=self.llm_client,
@@ -366,6 +411,20 @@ class LettaAgentV3(LettaAgentV2):
                 org_id=self.actor.organization_id,
                 user_id=self.actor.id,
             )
+        elif use_sglang_native:
+            # Use SGLang native adapter for multi-turn RL training
+            llm_adapter = SGLangNativeAdapter(
+                llm_client=self.llm_client,
+                llm_config=self.agent_state.llm_config,
+                call_type=LLMCallType.agent_step,
+                agent_id=self.agent_state.id,
+                agent_tags=self.agent_state.tags,
+                run_id=run_id,
+                org_id=self.actor.organization_id,
+                user_id=self.actor.id,
+            )
+            # Reset turns tracking for this step
+            self.turns = []
         else:
             llm_adapter = SimpleLLMRequestAdapter(
                 llm_client=self.llm_client,
@@ -488,7 +547,13 @@ class LettaAgentV3(LettaAgentV2):
         if run_id:
             # Filter out LettaStopReason from messages (only valid in LettaStreamingResponse, not LettaResponse)
             filtered_messages = [m for m in response_letta_messages if not isinstance(m, LettaStopReason)]
-            result = LettaResponse(messages=filtered_messages, stop_reason=self.stop_reason, usage=self.usage)
+            result = LettaResponse(
+                messages=filtered_messages,
+                stop_reason=self.stop_reason,
+                usage=self.usage,
+                logprobs=self.logprobs,
+                turns=self.turns if self.return_token_ids and self.turns else None,
+            )
             if self.job_update_metadata is None:
                 self.job_update_metadata = {}
             self.job_update_metadata["result"] = result.model_dump(mode="json")
@@ -970,6 +1035,19 @@ class LettaAgentV3(LettaAgentV2):
             self.context_token_estimate = llm_adapter.usage.total_tokens
             self.logger.info(f"Context token estimate after LLM request: {self.context_token_estimate}")

+        # Extract logprobs if present (for RL training)
+        if llm_adapter.logprobs is not None:
+            self.logprobs = llm_adapter.logprobs
+
+        # Track turn data for multi-turn RL training (SGLang native mode)
+        if self.return_token_ids and hasattr(llm_adapter, "output_ids") and llm_adapter.output_ids:
+            self.turns.append(TurnTokenData(
+                role="assistant",
+                output_ids=llm_adapter.output_ids,
+                output_token_logprobs=llm_adapter.output_token_logprobs,
+                content=llm_adapter.chat_completions_response.choices[0].message.content if llm_adapter.chat_completions_response else None,
+            ))
+
         # Handle the AI response with the extracted data (supports multiple tool calls)
         # Gather tool calls - check for multi-call API first, then fall back to single
         if hasattr(llm_adapter, "tool_calls") and llm_adapter.tool_calls:
@@ -1015,6 +1093,34 @@ class LettaAgentV3(LettaAgentV2):
         # extend trackers with new messages
         self.response_messages.extend(new_messages)
         messages.extend(new_messages)

+        # Track tool return turns for multi-turn RL training
+        if self.return_token_ids:
+            for msg in new_messages:
+                if msg.role == "tool":
+                    # Get tool return content
+                    tool_content = None
+                    tool_name = None
+                    if hasattr(msg, "tool_returns") and msg.tool_returns:
+                        # Aggregate all tool returns into content (func_response is the actual content)
+                        parts = []
+                        for tr in msg.tool_returns:
+                            if hasattr(tr, 'func_response') and tr.func_response:
+                                if isinstance(tr.func_response, str):
+                                    parts.append(tr.func_response)
+                                else:
+                                    parts.append(str(tr.func_response))
+                        tool_content = "\n".join(parts)
+                    elif hasattr(msg, "content") and msg.content:
+                        tool_content = msg.content if isinstance(msg.content, str) else str(msg.content)
+                    if hasattr(msg, "name"):
+                        tool_name = msg.name
+                    if tool_content:
+                        self.turns.append(TurnTokenData(
+                            role="tool",
+                            content=tool_content,
+                            tool_name=tool_name,
+                        ))
+
         # step(...) has successfully completed! now we can persist messages and update the in-context messages + save metrics
         # persistence needs to happen before streaming to minimize chances of agent getting into an inconsistent state