letta-server/letta/llm_api/sglang_native_client.py
Kevin Lin 23c94ec6d3 feat: add log probabilities from OpenAI-compatible servers and SGLang native endpoint (#9240)
* Add log probabilities support for RL training

This enables Letta server to request and return log probabilities from
OpenAI-compatible providers (including SGLang) for use in RL training.

Changes:
- LLMConfig: Add return_logprobs and top_logprobs fields
- OpenAIClient: Set logprobs in ChatCompletionRequest when enabled
- LettaLLMAdapter: Add logprobs field and extract from response
- LettaResponse: Add logprobs field to return log probs to client
- LettaRequest: Add return_logprobs/top_logprobs for per-request override
- LettaAgentV3: Store and pass logprobs through to response
- agents.py: Handle request-level logprobs override

Usage:
  response = client.agents.messages.create(
      agent_id=agent_id,
      messages=[...],
      return_logprobs=True,
      top_logprobs=5,
  )
  print(response.logprobs)  # Per-token log probabilities

🤖 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* Add multi-turn token tracking for RL training via SGLang native endpoint

- Add TurnTokenData schema to track token IDs and logprobs per turn
- Add return_token_ids flag to LettaRequest and LLMConfig
- Create SGLangNativeClient for /generate endpoint (returns output_ids)
- Create SGLangNativeAdapter that uses native endpoint
- Modify LettaAgentV3 to accumulate turns across LLM calls
- Include turns in LettaResponse when return_token_ids=True

* Fix: Add SGLang native adapter to step() method, not just stream()

* Fix: Handle Pydantic Message objects in SGLang native adapter

* Fix: Remove api_key reference from LLMConfig (not present)

* Fix: Add missing 'created' field to ChatCompletionResponse

* Add full tool support to SGLang native adapter

- Format tools into prompt in Qwen-style format
- Parse tool calls from <tool_call> tags in response
- Format tool results as <tool_response> in user messages
- Set finish_reason to 'tool_calls' when tools are called

* Use tokenizer.apply_chat_template for proper tool formatting

- Add tokenizer caching in SGLang native adapter
- Use apply_chat_template when tokenizer available
- Fall back to manual formatting if not
- Convert Letta messages to OpenAI format for tokenizer

* Fix: Use func_response instead of tool_return for ToolReturn content

* Fix: Get output_token_logprobs from meta_info in SGLang response

* Fix: Allow None in output_token_logprobs (SGLang format includes null)

* chore: remove unrelated files from logprobs branch

* fix: add missing call_type param to adapter constructors in letta_agent_v3

The SGLang refactor dropped call_type=LLMCallType.agent_step when extracting
adapter creation into conditional blocks. Restores it for all 3 spots (SGLang
in step, SimpleLLM in step, SGLang in stream).

* just stage-api && just publish-api

* fix: update expected LLMConfig fields in schema test for logprobs support

* chore: remove rllm provider references

* just stage-api && just publish-api

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-65-206.ec2.internal>
Co-authored-by: Letta <noreply@letta.com>
2026-02-24 10:52:07 -08:00

109 lines
3.7 KiB
Python

"""
SGLang Native Client for Letta.

This client uses SGLang's native /generate endpoint instead of the OpenAI-compatible
/v1/chat/completions endpoint. The native endpoint returns token IDs and per-token
logprobs, which are essential for multi-turn RL training.

The OpenAI-compatible endpoint only returns token strings, not IDs, making it
impossible to accurately reconstruct the token sequence for training.
"""

from typing import Any, Dict, List, Optional

import httpx

from letta.log import get_logger

logger = get_logger(__name__)


class SGLangNativeClient:
    """Client for SGLang's native /generate endpoint.

    Unlike the OpenAI-compatible endpoint, this returns:
    - output_ids: List of token IDs
    - output_token_logprobs (under meta_info): List of [logprob, token_id, top_logprob] tuples

    This is essential for RL training where we need exact token IDs, not re-tokenized text.
    """

    def __init__(self, base_url: str, api_key: Optional[str] = None):
        """
        Initialize the SGLang native client.

        Args:
            base_url: Base URL for SGLang server (e.g., http://localhost:30000)
            api_key: Optional API key for authentication
        """
        # Remove /v1 suffix if present - native endpoint is at root
        self.base_url = base_url.rstrip("/")
        if self.base_url.endswith("/v1"):
            self.base_url = self.base_url[:-3]
        self.api_key = api_key

    async def generate(
        self,
        text: str,
        sampling_params: Optional[Dict[str, Any]] = None,
        return_logprob: bool = True,
    ) -> Dict[str, Any]:
        """
        Call SGLang's native /generate endpoint.

        Args:
            text: The formatted prompt text (with chat template applied)
            sampling_params: Sampling parameters (temperature, max_new_tokens, etc.)
            return_logprob: Whether to return logprobs (default True for RL training)

        Returns:
            Response dict with:
            - text: Generated text
            - output_ids: List of token IDs
            - meta_info: Metadata including finish_reason, prompt_tokens, and
              output_token_logprobs, a list of [logprob, token_id, top_logprob] tuples

        Example response:
            {
                "text": "Hello! How can I help?",
                "output_ids": [9707, 0, 2585, 646, 358, 1492, 30],
                "meta_info": {
                    "finish_reason": {"type": "stop", "matched": 151645},
                    "prompt_tokens": 42,
                    "output_token_logprobs": [
                        [-0.005, 9707, null],
                        [0.0, 0, null],
                        ...
                    ],
                    ...
                }
            }
        """
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        payload = {
            "text": text,
            "sampling_params": sampling_params or {},
            "return_logprob": return_logprob,
        }

        # Long timeout: a single generation can take minutes for long outputs
        async with httpx.AsyncClient(timeout=300.0) as client:
            response = await client.post(
                f"{self.base_url}/generate",
                json=payload,
                headers=headers,
            )
            response.raise_for_status()
            return response.json()

    async def health_check(self) -> bool:
        """Check if the SGLang server is healthy."""
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                response = await client.get(f"{self.base_url}/health")
                return response.status_code == 200
        except Exception:
            # Any connection or timeout error is treated as "not healthy"
            return False
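
The response documented in generate() can be turned into (token ID, logprob) pairs for training along the lines of the sketch below. It is not part of this file: extract_token_logprobs and sequence_logprob are hypothetical helper names, and the sketch assumes output_token_logprobs sits under meta_info (as the commit message notes), with a fallback to a top-level key. It also assumes each entry is a [logprob, token_id, ...] list and skips null entries.

```python
from typing import Any, Dict, List, Optional, Tuple

TokenLogprob = Tuple[int, Optional[float]]


def extract_token_logprobs(response: Dict[str, Any]) -> List[TokenLogprob]:
    """Pair each generated token ID with its logprob (hypothetical helper)."""
    meta_info = response.get("meta_info") or {}
    # Assumption: prefer meta_info, fall back to a top-level key if absent.
    entries = meta_info.get("output_token_logprobs")
    if entries is None:
        entries = response.get("output_token_logprobs") or []

    pairs: List[TokenLogprob] = []
    for entry in entries:
        if entry is None:
            # Null entries carry no token information, so skip them.
            continue
        logprob, token_id = entry[0], entry[1]
        pairs.append((token_id, logprob))
    return pairs


def sequence_logprob(pairs: List[TokenLogprob]) -> float:
    """Sum the per-token logprobs, ignoring tokens without a reported value."""
    return sum(lp for _, lp in pairs if lp is not None)


# Example shaped like the response in the generate() docstring (truncated).
example = {
    "text": "Hello!",
    "output_ids": [9707, 0],
    "meta_info": {
        "finish_reason": {"type": "stop", "matched": 151645},
        "prompt_tokens": 42,
        "output_token_logprobs": [[-0.005, 9707, None], [0.0, 0, None]],
    },
}

pairs = extract_token_logprobs(example)
print(pairs)
print(sequence_logprob(pairs))
```

Reading the token ID from the logprob entry itself, rather than re-tokenizing the returned text, keeps the IDs aligned with what the server actually sampled, which is the reason the native endpoint is used here.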