* Add log probabilities support for RL training
This enables the Letta server to request and return log probabilities from
OpenAI-compatible providers (including SGLang) for use in RL training.
Changes:
- LLMConfig: Add return_logprobs and top_logprobs fields (see the config sketch after this list)
- OpenAIClient: Set logprobs in ChatCompletionRequest when enabled
- LettaLLMAdapter: Add logprobs field and extract from response
- LettaResponse: Add logprobs field to return log probs to client
- LettaRequest: Add return_logprobs/top_logprobs for per-request override
- LettaAgentV3: Store and pass logprobs through to response
- agents.py: Handle request-level logprobs override
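A hedged sketch of the new config-level fields (the other LLMConfig arguments shown are illustrative, not part of this change):

    llm_config = LLMConfig(
        model="qwen2.5-7b-instruct",                 # illustrative model name
        model_endpoint="http://localhost:30000/v1",  # illustrative endpoint
        return_logprobs=True,  # new field: ask the provider for logprobs
        top_logprobs=5,        # new field: number of alternatives per token
    )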
Usage:
    response = client.agents.messages.create(
        agent_id=agent_id,
        messages=[...],
        return_logprobs=True,
        top_logprobs=5,
    )
    print(response.logprobs)  # Per-token log probabilities
🤖 Generated with [Letta Code](https://letta.com)
Co-Authored-By: Letta <noreply@letta.com>
* Add multi-turn token tracking for RL training via SGLang native endpoint
- Add TurnTokenData schema to track token IDs and logprobs per turn (sketched after this list)
- Add return_token_ids flag to LettaRequest and LLMConfig
- Create SGLangNativeClient for /generate endpoint (returns output_ids)
- Create SGLangNativeAdapter that uses native endpoint
- Modify LettaAgentV3 to accumulate turns across LLM calls
- Include turns in LettaResponse when return_token_ids=True
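A minimal sketch of what the TurnTokenData schema could hold, based on the bullets above (exact field names beyond token IDs and logprobs are assumptions):

    from typing import List, Optional
    from pydantic import BaseModel

    class TurnTokenData(BaseModel):
        """Token-level data for one LLM call in a multi-turn rollout (sketch)."""
        input_token_ids: List[int]                    # assumed: prompt tokens for this turn
        output_token_ids: List[int]                   # token IDs generated this turn
        output_token_logprobs: List[Optional[float]]  # per-token logprobs; None allowed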
* Fix: Add SGLang native adapter to step() method, not just stream()
* Fix: Handle Pydantic Message objects in SGLang native adapter
* Fix: Remove api_key reference from LLMConfig (not present)
* Fix: Add missing 'created' field to ChatCompletionResponse
* Add full tool support to SGLang native adapter
- Format tool definitions into the prompt in Qwen style
- Parse tool calls from <tool_call> tags in the response (see the parsing sketch below)
- Format tool results as <tool_response> in user messages
- Set finish_reason to 'tool_calls' when tools are called
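One plausible way to parse the <tool_call> tags described above (the adapter's actual parsing may differ):

    import json
    import re

    TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

    def parse_tool_calls(text: str) -> list[dict]:
        # Each match is expected to be JSON like {"name": ..., "arguments": {...}}
        calls = []
        for match in TOOL_CALL_RE.finditer(text):
            try:
                calls.append(json.loads(match.group(1)))
            except json.JSONDecodeError:
                continue  # skip malformed blocks rather than failing the turn
        return calls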
* Use tokenizer.apply_chat_template for proper tool formatting
- Add tokenizer caching in SGLang native adapter
- Use apply_chat_template when a tokenizer is available (sketched below)
- Fall back to manual formatting when it is not
- Convert Letta messages to OpenAI format for tokenizer
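A sketch of the template path (model name and message/tool shapes are illustrative; recent transformers versions accept a tools argument in apply_chat_template):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
    prompt = tokenizer.apply_chat_template(
        messages,     # OpenAI-format [{"role": ..., "content": ...}, ...]
        tools=tools,  # OpenAI-format function definitions
        tokenize=False,
        add_generation_prompt=True,
    )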
* Fix: Use func_response instead of tool_return for ToolReturn content
* Fix: Get output_token_logprobs from meta_info in SGLang response
* Fix: Allow None in output_token_logprobs (SGLang format includes null)
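A sketch of unpacking those tuples while tolerating nulls (response shape per the two fixes above; variable names are illustrative):

    # meta_info["output_token_logprobs"] is a list of [logprob, token_id, top_logprob]
    logprobs, token_ids = [], []
    for lp, tok, _top in response["meta_info"]["output_token_logprobs"]:
        logprobs.append(lp)   # may be None; SGLang emits null for some tokens
        token_ids.append(tok)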
* chore: remove unrelated files from logprobs branch
🤖 Generated with [Letta Code](https://letta.com)
Co-Authored-By: Letta <noreply@letta.com>
* fix: add missing call_type param to adapter constructors in letta_agent_v3
The SGLang refactor dropped call_type=LLMCallType.agent_step when adapter creation
was extracted into conditional blocks. This restores it in all three spots (SGLang
in step, SimpleLLM in step, SGLang in stream).
🤖 Generated with [Letta Code](https://letta.com)
Co-Authored-By: Letta <noreply@letta.com>
* just stage-api && just publish-api
* fix: update expected LLMConfig fields in schema test for logprobs support
🤖 Generated with [Letta Code](https://letta.com)
Co-Authored-By: Letta <noreply@letta.com>
* chore: remove rllm provider references
🤖 Generated with [Letta Code](https://letta.com)
Co-Authored-By: Letta <noreply@letta.com>
* just stage-api && just publish-api
🤖 Generated with [Letta Code](https://letta.com)
Co-Authored-By: Letta <noreply@letta.com>
---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-65-206.ec2.internal>
Co-authored-by: Letta <noreply@letta.com>
"""
|
|
SGLang Native Client for Letta.
|
|
|
|
This client uses SGLang's native /generate endpoint instead of the OpenAI-compatible
|
|
/v1/chat/completions endpoint. The native endpoint returns token IDs and per-token
|
|
logprobs, which are essential for multi-turn RL training.
|
|
|
|
The OpenAI-compatible endpoint only returns token strings, not IDs, making it
|
|
impossible to accurately reconstruct the token sequence for training.
|
|
"""
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import httpx
|
|
|
|
from letta.log import get_logger
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class SGLangNativeClient:
|
|
"""Client for SGLang's native /generate endpoint.
|
|
|
|
Unlike the OpenAI-compatible endpoint, this returns:
|
|
- output_ids: List of token IDs
|
|
- output_token_logprobs: List of [logprob, token_id, top_logprob] tuples
|
|
|
|
This is essential for RL training where we need exact token IDs, not re-tokenized text.
|
|
"""
|
|
|
|
def __init__(self, base_url: str, api_key: Optional[str] = None):
|
|
"""
|
|
Initialize the SGLang native client.
|
|
|
|
Args:
|
|
base_url: Base URL for SGLang server (e.g., http://localhost:30000)
|
|
api_key: Optional API key for authentication
|
|
"""
|
|
# Remove /v1 suffix if present - native endpoint is at root
|
|
self.base_url = base_url.rstrip("/")
|
|
if self.base_url.endswith("/v1"):
|
|
self.base_url = self.base_url[:-3]
|
|
self.api_key = api_key
|
|
|
|
async def generate(
|
|
self,
|
|
text: str,
|
|
sampling_params: Optional[Dict[str, Any]] = None,
|
|
return_logprob: bool = True,
|
|
) -> Dict[str, Any]:
|
|
"""
|
|
Call SGLang's native /generate endpoint.
|
|
|
|
Args:
|
|
text: The formatted prompt text (with chat template applied)
|
|
sampling_params: Sampling parameters (temperature, max_new_tokens, etc.)
|
|
return_logprob: Whether to return logprobs (default True for RL training)
|
|
|
|
Returns:
|
|
Response dict with:
|
|
- text: Generated text
|
|
- output_ids: List of token IDs
|
|
- output_token_logprobs: List of [logprob, token_id, top_logprob] tuples
|
|
- meta_info: Metadata including finish_reason, prompt_tokens, etc.
|
|
|
|
Example response:
|
|
{
|
|
"text": "Hello! How can I help?",
|
|
"output_ids": [9707, 0, 2585, 646, 358, 1492, 30],
|
|
"output_token_logprobs": [
|
|
[-0.005, 9707, null],
|
|
[0.0, 0, null],
|
|
...
|
|
],
|
|
"meta_info": {
|
|
"finish_reason": {"type": "stop", "matched": 151645},
|
|
"prompt_tokens": 42,
|
|
...
|
|
}
|
|
}
|
|
"""
|
|
headers = {"Content-Type": "application/json"}
|
|
if self.api_key:
|
|
headers["Authorization"] = f"Bearer {self.api_key}"
|
|
|
|
payload = {
|
|
"text": text,
|
|
"sampling_params": sampling_params or {},
|
|
"return_logprob": return_logprob,
|
|
}
|
|
|
|
async with httpx.AsyncClient(timeout=300.0) as client:
|
|
response = await client.post(
|
|
f"{self.base_url}/generate",
|
|
json=payload,
|
|
headers=headers,
|
|
)
|
|
response.raise_for_status()
|
|
return response.json()
|
|
|
|
async def health_check(self) -> bool:
|
|
"""Check if the SGLang server is healthy."""
|
|
try:
|
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
|
response = await client.get(f"{self.base_url}/health")
|
|
return response.status_code == 200
|
|
except Exception:
|
|
return False
|
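A minimal usage sketch for the client above (prompt and sampling parameters are illustrative):

    import asyncio

    async def main():
        client = SGLangNativeClient(base_url="http://localhost:30000")
        if await client.health_check():
            result = await client.generate(
                "Hello!",
                sampling_params={"temperature": 0.7, "max_new_tokens": 64},
            )
            print(result["text"])
            print(result["meta_info"]["output_token_logprobs"][:3])

    asyncio.run(main())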