* Add log probabilities support for RL training
This enables the Letta server to request and return log probabilities from
OpenAI-compatible providers (including SGLang) for use in RL training.
Changes:
- LLMConfig: Add return_logprobs and top_logprobs fields
- OpenAIClient: Set logprobs in ChatCompletionRequest when enabled
- LettaLLMAdapter: Add logprobs field and extract from response
- LettaResponse: Add logprobs field to return log probs to client
- LettaRequest: Add return_logprobs/top_logprobs for per-request override
- LettaAgentV3: Store and pass logprobs through to response
- agents.py: Handle request-level logprobs override
Usage:
    response = client.agents.messages.create(
        agent_id=agent_id,
        messages=[...],
        return_logprobs=True,
        top_logprobs=5,
    )
    print(response.logprobs)  # Per-token log probabilities
🤖 Generated with [Letta Code](https://letta.com)
Co-Authored-By: Letta <noreply@letta.com>
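The request-level override listed in this commit can be pictured as a small merge step. The sketch below is an illustration only, not the code in agents.py: `LogprobsConfig` and `apply_request_override` are hypothetical names, and it assumes that request flags replace the agent-level values only when the request sets `return_logprobs=True`.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class LogprobsConfig:
    """Hypothetical stand-in for the logprobs-related fields on LLMConfig."""

    return_logprobs: bool = False
    top_logprobs: Optional[int] = None


def apply_request_override(
    config: LogprobsConfig,
    return_logprobs: bool = False,
    top_logprobs: Optional[int] = None,
) -> LogprobsConfig:
    """Return a copy of `config` with request-level values applied (assumed semantics)."""
    if not return_logprobs:
        # The request did not ask for logprobs, so the agent defaults stay as they are.
        return config
    # Keep the agent-level top_logprobs when the request does not supply one.
    effective_top = top_logprobs if top_logprobs is not None else config.top_logprobs
    return replace(config, return_logprobs=True, top_logprobs=effective_top)


agent_config = LogprobsConfig()
effective = apply_request_override(agent_config, return_logprobs=True, top_logprobs=5)
print(effective)
```

Returning a copy rather than mutating the stored configuration keeps the override scoped to a single request, which matches the "per-request override" wording above.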
* Add multi-turn token tracking for RL training via SGLang native endpoint
- Add TurnTokenData schema to track token IDs and logprobs per turn
- Add return_token_ids flag to LettaRequest and LLMConfig
- Create SGLangNativeClient for /generate endpoint (returns output_ids)
- Create SGLangNativeAdapter that uses native endpoint
- Modify LettaAgentV3 to accumulate turns across LLM calls
- Include turns in LettaResponse when return_token_ids=True
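To show why per-turn token tracking matters for loss masking, here is a minimal sketch of turning accumulated turns into a flat token sequence with a 0/1 mask. The `Turn` class and its field names are assumptions made for illustration; the real `TurnTokenData` schema is not shown here and may differ.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Turn:
    """Hypothetical per-turn record; the real TurnTokenData fields may differ."""

    role: str  # "assistant" or "tool" (assumed role labels)
    token_ids: List[int]
    logprobs: Optional[List[Optional[float]]] = None


def build_loss_mask(turns: List[Turn]) -> Tuple[List[int], List[int]]:
    """Flatten turns into one token sequence and a parallel 0/1 loss mask."""
    tokens: List[int] = []
    mask: List[int] = []
    for turn in turns:
        tokens.extend(turn.token_ids)
        # Only tokens the policy generated itself should contribute to the loss.
        keep = 1 if turn.role == "assistant" else 0
        mask.extend([keep] * len(turn.token_ids))
    return tokens, mask


turns = [
    Turn(role="assistant", token_ids=[11, 12, 13], logprobs=[-0.1, -0.5, -0.2]),
    Turn(role="tool", token_ids=[21, 22]),
    Turn(role="assistant", token_ids=[31], logprobs=[-0.3]),
]
tokens, mask = build_loss_mask(turns)
print(tokens)  # [11, 12, 13, 21, 22, 31]
print(mask)    # [1, 1, 1, 0, 0, 1]
```

Tool output tokens stay in the sequence so later assistant tokens are conditioned on them, but the mask excludes them from the gradient.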
* Fix: Add SGLang native adapter to step() method, not just stream()
* Fix: Handle Pydantic Message objects in SGLang native adapter
* Fix: Remove api_key reference from LLMConfig (not present)
* Fix: Add missing 'created' field to ChatCompletionResponse
* Add full tool support to SGLang native adapter
- Format tools into prompt in Qwen-style format
- Parse tool calls from <tool_call> tags in response
- Format tool results as <tool_response> in user messages
- Set finish_reason to 'tool_calls' when tools are called
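The parsing step in the list above can be sketched as follows. This is not the adapter's actual code: `parse_tool_calls` and `finish_reason_for` are hypothetical helpers, and the sketch assumes each `<tool_call>` block wraps a JSON object with `name` and `arguments` keys, with `"stop"` assumed as the finish reason when no tool is called.

```python
import json
import re
from typing import Any, Dict, List

# Matches the body of each <tool_call>...</tool_call> block, across newlines.
TOOL_CALL_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)


def parse_tool_calls(text: str) -> List[Dict[str, Any]]:
    """Extract tool calls from model output, skipping blocks that are not valid JSON."""
    calls: List[Dict[str, Any]] = []
    for payload in TOOL_CALL_RE.findall(text):
        try:
            data = json.loads(payload)
        except json.JSONDecodeError:
            continue  # skip malformed blocks rather than fail the whole turn
        if isinstance(data, dict) and "name" in data:
            calls.append({"name": data["name"], "arguments": data.get("arguments", {})})
    return calls


def finish_reason_for(text: str) -> str:
    """Report 'tool_calls' when at least one tool call was parsed, else 'stop' (assumed)."""
    return "tool_calls" if parse_tool_calls(text) else "stop"


sample = 'Let me check.\n<tool_call>\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n</tool_call>'
print(parse_tool_calls(sample))
print(finish_reason_for(sample))
```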
* Use tokenizer.apply_chat_template for proper tool formatting
- Add tokenizer caching in SGLang native adapter
- Use apply_chat_template when tokenizer available
- Fall back to manual formatting if not
- Convert Letta messages to OpenAI format for tokenizer
* Fix: Use func_response instead of tool_return for ToolReturn content
* Fix: Get output_token_logprobs from meta_info in SGLang response
* Fix: Allow None in output_token_logprobs (SGLang format includes null)
* chore: remove unrelated files from logprobs branch
* fix: add missing call_type param to adapter constructors in letta_agent_v3
The SGLang refactor dropped call_type=LLMCallType.agent_step when extracting
adapter creation into conditional blocks. Restores it for all 3 spots (SGLang
in step, SimpleLLM in step, SGLang in stream).
* just stage-api && just publish-api
* fix: update expected LLMConfig fields in schema test for logprobs support
* chore: remove rllm provider references
* just stage-api && just publish-api
---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-65-206.ec2.internal>
Co-authored-by: Letta <noreply@letta.com>
212 lines
9.5 KiB
Python
import uuid
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator

from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.schemas.letta_message import MessageType
from letta.schemas.letta_message_content import LettaMessageContentUnion
from letta.schemas.message import MessageCreate, MessageCreateUnion, MessageRole
from letta.validators import AgentId


class ClientToolSchema(BaseModel):
    """Schema for a client-side tool passed in the request.

    Client-side tools are executed by the client, not the server. When the agent
    calls a client-side tool, execution pauses and returns control to the client
    to execute the tool and provide the result.
    """

    name: str = Field(..., description="The name of the tool function")
    description: Optional[str] = Field(None, description="Description of what the tool does")
    parameters: Optional[Dict[str, Any]] = Field(None, description="JSON Schema for the function parameters")


class LettaRequest(BaseModel):
    messages: Optional[List[MessageCreateUnion]] = Field(None, description="The messages to be sent to the agent.")
    input: Optional[Union[str, List[LettaMessageContentUnion]]] = Field(
        None, description="Syntactic sugar for a single user message. Equivalent to messages=[{'role': 'user', 'content': input}]."
    )
    max_steps: int = Field(
        default=DEFAULT_MAX_STEPS,
        description="Maximum number of steps the agent should take to process the request.",
    )
    use_assistant_message: bool = Field(
        default=True,
        description="Whether the server should parse specific tool call arguments (default `send_message`) as `AssistantMessage` objects. Still supported for legacy agent types, but deprecated for letta_v1_agent onward.",
        deprecated=True,
    )
    assistant_message_tool_name: str = Field(
        default=DEFAULT_MESSAGE_TOOL,
        description="The name of the designated message tool. Still supported for legacy agent types, but deprecated for letta_v1_agent onward.",
        deprecated=True,
    )
    assistant_message_tool_kwarg: str = Field(
        default=DEFAULT_MESSAGE_TOOL_KWARG,
        description="The name of the message argument in the designated message tool. Still supported for legacy agent types, but deprecated for letta_v1_agent onward.",
        deprecated=True,
    )

    # filter to only return specific message types
    include_return_message_types: Optional[List[MessageType]] = Field(
        default=None, description="Only return specified message types in the response. If `None` (default) returns all messages."
    )

    # Annotated as bool to match the boolean default and the description below.
    enable_thinking: bool = Field(
        default=True,
        description="If set to True, enables reasoning before responses or tool calls from the agent.",
        deprecated=True,
    )

    # Client-side tools
    client_tools: Optional[List[ClientToolSchema]] = Field(
        None,
        description="Client-side tools that the agent can call. When the agent calls a client-side tool, "
        "execution pauses and returns control to the client to execute the tool and provide the result via a ToolReturn.",
    )

    # Model override
    override_model: Optional[str] = Field(
        None,
        description="Model handle to use for this request instead of the agent's default model. "
        "This allows sending a message to a different model without changing the agent's configuration.",
    )

    # Compaction message format
    include_compaction_messages: bool = Field(
        default=False,
        description="If True, compaction events emit structured `SummaryMessage` and `EventMessage` types. "
        "If False (default), compaction messages are not included in the response.",
    )

    # Log probabilities for RL training
    return_logprobs: bool = Field(
        default=False,
        description="If True, returns log probabilities of the output tokens in the response. "
        "Useful for RL training. Only supported for OpenAI-compatible providers (including SGLang).",
    )
    top_logprobs: Optional[int] = Field(
        default=None,
        description="Number of most likely tokens to return at each position (0-20). "
        "Requires return_logprobs=True.",
    )
    return_token_ids: bool = Field(
        default=False,
        description="If True, returns token IDs and logprobs for ALL LLM generations in the agent step, "
        "not just the last one. Uses SGLang native /generate endpoint. "
        "Returns 'turns' field with TurnTokenData for each assistant/tool turn. "
        "Required for proper multi-turn RL training with loss masking.",
    )

    @field_validator("messages", mode="before")
    @classmethod
    def add_default_type_to_messages(cls, v):
        """Handle union without discriminator - default to 'message' type if not specified"""
        if isinstance(v, list):
            for item in v:
                if isinstance(item, dict):
                    # If type is not present, determine based on fields
                    if "type" not in item:
                        # If it has approval-specific fields, it's an approval
                        if "approval_request_id" in item or "approve" in item:
                            item["type"] = "approval"
                        else:
                            # Default to message
                            item["type"] = "message"
        return v

    @model_validator(mode="after")
    def validate_input_or_messages(self):
        """Ensure exactly one of input or messages is set, and convert input to messages if needed"""
        if self.input is not None and self.messages is not None:
            raise ValueError("Cannot specify both 'input' and 'messages'. Use one or the other.")
        if self.input is None and self.messages is None:
            raise ValueError("Must specify either 'input' or 'messages'.")

        # Convert input to messages format
        # input can be either a string or List[LettaMessageContentUnion]
        if self.input is not None:
            # Both str and List[LettaMessageContentUnion] are valid content types for MessageCreate
            self.messages = [MessageCreate(role=MessageRole.user, content=self.input, otid=str(uuid.uuid4()))]

        return self


class LettaStreamingRequest(LettaRequest):
    streaming: bool = Field(
        default=False,
        description="If True, returns a streaming response (Server-Sent Events). If False (default), returns a complete response.",
    )
    stream_tokens: bool = Field(
        default=False,
        description="Flag to determine if individual tokens should be streamed, rather than streaming per step (only used when streaming=true).",
    )
    include_pings: bool = Field(
        default=True,
        description="Whether to include periodic keepalive ping messages in the stream to prevent connection timeouts (only used when streaming=true).",
    )
    background: bool = Field(
        default=False,
        description="Whether to process the request in the background (only used when streaming=true).",
    )


class ConversationMessageRequest(LettaRequest):
    """Request for sending messages to a conversation. Streams by default."""

    streaming: bool = Field(
        default=True,
        description="If True (default), returns a streaming response (Server-Sent Events). If False, returns a complete JSON response.",
    )
    stream_tokens: bool = Field(
        default=False,
        description="Flag to determine if individual tokens should be streamed, rather than streaming per step (only used when streaming=true).",
    )
    include_pings: bool = Field(
        default=True,
        description="Whether to include periodic keepalive ping messages in the stream to prevent connection timeouts (only used when streaming=true).",
    )
    background: bool = Field(
        default=False,
        description="Whether to process the request in the background (only used when streaming=true).",
    )


class LettaAsyncRequest(LettaRequest):
    callback_url: Optional[str] = Field(None, description="Optional callback URL to POST to when the job completes")


class LettaBatchRequest(LettaRequest):
    agent_id: AgentId = Field(..., description="The ID of the agent to send this batch request for")


class CreateBatch(BaseModel):
    requests: List[LettaBatchRequest] = Field(..., description="List of requests to be processed in batch.")
    callback_url: Optional[HttpUrl] = Field(
        None,
        description="Optional URL to call via POST when the batch completes. The callback payload will be a JSON object with the following fields: "
        "{'job_id': string, 'status': string, 'completed_at': string}. "
        "Where 'job_id' is the unique batch job identifier, "
        "'status' is the final batch status (e.g., 'completed', 'failed'), and "
        "'completed_at' is an ISO 8601 timestamp indicating when the batch job completed.",
    )


class RetrieveStreamRequest(BaseModel):
    starting_after: int = Field(
        0, description="Sequence id to use as a cursor for pagination. Response will start streaming after this chunk sequence id"
    )
    include_pings: Optional[bool] = Field(
        default=True,
        description="Whether to include periodic keepalive ping messages in the stream to prevent connection timeouts.",
    )
    poll_interval: Optional[float] = Field(
        default=0.1,
        description="Seconds to wait between polls when no new data.",
    )
    batch_size: Optional[int] = Field(
        default=100,
        description="Number of entries to read per batch.",
    )
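
The `top_logprobs` field description in `LettaRequest` states a 0-20 range and a dependency on `return_logprobs=True`, but the schema shown here does not enforce either. A check along these lines could do so. `check_logprobs_options` is a hypothetical helper, not part of the codebase, and it is written with the standard library only so that it runs without Letta or pydantic installed.

```python
from typing import Optional


def check_logprobs_options(return_logprobs: bool, top_logprobs: Optional[int]) -> None:
    """Raise ValueError if the combination contradicts the field descriptions."""
    if top_logprobs is None:
        return  # nothing to validate when the option is not set
    if not return_logprobs:
        raise ValueError("top_logprobs requires return_logprobs=True.")
    if not 0 <= top_logprobs <= 20:
        raise ValueError("top_logprobs must be between 0 and 20.")


check_logprobs_options(True, 5)      # valid: logprobs requested, value in range
check_logprobs_options(False, None)  # valid: option left unset

for args in [(False, 5), (True, 21)]:
    try:
        check_logprobs_options(*args)
    except ValueError as exc:
        print(f"{args}: {exc}")
```

In the pydantic model, the same conditions could sit next to the existing checks in `validate_input_or_messages`, since that validator already runs after all fields are populated.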