chore: enable F821, F401, W293 (#9503)
* auto fixes * auto fix pt2 and transitive deps and undefined var checking locals() * manual fixes (ignored or letta-code fixed) * fix circular import
This commit is contained in:
@@ -19,18 +19,17 @@ from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
||||
from letta.llm_api.sglang_native_client import SGLangNativeClient
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, TextContent
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.openai.chat_completion_response import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionTokenLogprob,
|
||||
Choice,
|
||||
ChoiceLogprobs,
|
||||
ChatCompletionTokenLogprob,
|
||||
FunctionCall,
|
||||
Message as ChoiceMessage,
|
||||
ToolCall,
|
||||
UsageStatistics,
|
||||
)
|
||||
from letta.schemas.usage import normalize_cache_tokens, normalize_reasoning_tokens
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -41,37 +40,38 @@ _tokenizer_cache: dict[str, Any] = {}
|
||||
class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
||||
"""
|
||||
Adapter that uses SGLang's native /generate endpoint for multi-turn RL training.
|
||||
|
||||
|
||||
Key differences from SimpleLLMRequestAdapter:
|
||||
- Uses /generate instead of /v1/chat/completions
|
||||
- Returns output_ids (token IDs) in addition to text
|
||||
- Returns output_token_logprobs with [logprob, token_id] pairs
|
||||
- Formats tools into prompt and parses tool calls from response
|
||||
|
||||
|
||||
These are essential for building accurate loss masks in multi-turn training.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._sglang_client: Optional[SGLangNativeClient] = None
|
||||
self._tokenizer: Any = None
|
||||
|
||||
|
||||
def _get_tokenizer(self) -> Any:
|
||||
"""Get or create tokenizer for the model."""
|
||||
global _tokenizer_cache
|
||||
|
||||
|
||||
# Get model name from llm_config
|
||||
model_name = self.llm_config.model
|
||||
if not model_name:
|
||||
logger.warning("No model name in llm_config, cannot load tokenizer")
|
||||
return None
|
||||
|
||||
|
||||
# Check cache
|
||||
if model_name in _tokenizer_cache:
|
||||
return _tokenizer_cache[model_name]
|
||||
|
||||
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
logger.info(f"Loading tokenizer for model: {model_name}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
_tokenizer_cache[model_name] = tokenizer
|
||||
@@ -82,7 +82,7 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load tokenizer: {e}, falling back to manual formatting")
|
||||
return None
|
||||
|
||||
|
||||
def _get_sglang_client(self) -> SGLangNativeClient:
|
||||
"""Get or create SGLang native client."""
|
||||
if self._sglang_client is None:
|
||||
@@ -94,17 +94,17 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
||||
api_key=None,
|
||||
)
|
||||
return self._sglang_client
|
||||
|
||||
|
||||
def _format_tools_for_prompt(self, tools: list) -> str:
|
||||
"""
|
||||
Format tools in Qwen3 chat template format for the system prompt.
|
||||
|
||||
|
||||
This matches the exact format produced by Qwen3's tokenizer.apply_chat_template()
|
||||
with tools parameter.
|
||||
"""
|
||||
if not tools:
|
||||
return ""
|
||||
|
||||
|
||||
# Format each tool as JSON (matching Qwen3 template exactly)
|
||||
tool_jsons = []
|
||||
for tool in tools:
|
||||
@@ -120,84 +120,85 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
||||
"name": getattr(getattr(tool, "function", tool), "name", ""),
|
||||
"description": getattr(getattr(tool, "function", tool), "description", ""),
|
||||
"parameters": getattr(getattr(tool, "function", tool), "parameters", {}),
|
||||
}
|
||||
},
|
||||
}
|
||||
tool_jsons.append(json.dumps(tool_dict))
|
||||
|
||||
|
||||
# Use exact Qwen3 format
|
||||
tools_section = (
|
||||
"\n\n# Tools\n\n"
|
||||
"You may call one or more functions to assist with the user query.\n\n"
|
||||
"You are provided with function signatures within <tools></tools> XML tags:\n"
|
||||
"<tools>\n"
|
||||
+ "\n".join(tool_jsons) + "\n"
|
||||
"<tools>\n" + "\n".join(tool_jsons) + "\n"
|
||||
"</tools>\n\n"
|
||||
"For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n"
|
||||
"<tool_call>\n"
|
||||
'{"name": <function-name>, "arguments": <args-json-object>}\n'
|
||||
"</tool_call>"
|
||||
)
|
||||
|
||||
|
||||
return tools_section
|
||||
|
||||
|
||||
def _convert_messages_to_openai_format(self, messages: list) -> list[dict]:
|
||||
"""Convert Letta Message objects to OpenAI-style message dicts."""
|
||||
openai_messages = []
|
||||
|
||||
|
||||
for msg in messages:
|
||||
# Handle both dict and Pydantic Message objects
|
||||
if hasattr(msg, 'role'):
|
||||
if hasattr(msg, "role"):
|
||||
role = msg.role
|
||||
content = msg.content if hasattr(msg, 'content') else ""
|
||||
content = msg.content if hasattr(msg, "content") else ""
|
||||
# Handle content that might be a list of content parts
|
||||
if isinstance(content, list):
|
||||
content = " ".join([c.text if hasattr(c, 'text') else str(c) for c in content])
|
||||
content = " ".join([c.text if hasattr(c, "text") else str(c) for c in content])
|
||||
elif content is None:
|
||||
content = ""
|
||||
tool_calls = getattr(msg, 'tool_calls', None)
|
||||
tool_call_id = getattr(msg, 'tool_call_id', None)
|
||||
name = getattr(msg, 'name', None)
|
||||
tool_calls = getattr(msg, "tool_calls", None)
|
||||
tool_call_id = getattr(msg, "tool_call_id", None)
|
||||
name = getattr(msg, "name", None)
|
||||
else:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
tool_calls = msg.get("tool_calls", None)
|
||||
tool_call_id = msg.get("tool_call_id", None)
|
||||
name = msg.get("name", None)
|
||||
|
||||
|
||||
openai_msg = {"role": role, "content": content}
|
||||
|
||||
|
||||
if tool_calls:
|
||||
# Convert tool calls to OpenAI format
|
||||
openai_tool_calls = []
|
||||
for tc in tool_calls:
|
||||
if hasattr(tc, 'function'):
|
||||
if hasattr(tc, "function"):
|
||||
tc_dict = {
|
||||
"id": getattr(tc, 'id', f"call_{uuid.uuid4().hex[:8]}"),
|
||||
"id": getattr(tc, "id", f"call_{uuid.uuid4().hex[:8]}"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments if isinstance(tc.function.arguments, str) else json.dumps(tc.function.arguments)
|
||||
}
|
||||
"arguments": tc.function.arguments
|
||||
if isinstance(tc.function.arguments, str)
|
||||
else json.dumps(tc.function.arguments),
|
||||
},
|
||||
}
|
||||
else:
|
||||
tc_dict = {
|
||||
"id": tc.get("id", f"call_{uuid.uuid4().hex[:8]}"),
|
||||
"type": "function",
|
||||
"function": tc.get("function", {})
|
||||
"function": tc.get("function", {}),
|
||||
}
|
||||
openai_tool_calls.append(tc_dict)
|
||||
openai_msg["tool_calls"] = openai_tool_calls
|
||||
|
||||
|
||||
if tool_call_id:
|
||||
openai_msg["tool_call_id"] = tool_call_id
|
||||
|
||||
|
||||
if name and role == "tool":
|
||||
openai_msg["name"] = name
|
||||
|
||||
|
||||
openai_messages.append(openai_msg)
|
||||
|
||||
|
||||
return openai_messages
|
||||
|
||||
|
||||
def _convert_tools_to_openai_format(self, tools: list) -> list[dict]:
|
||||
"""Convert tools to OpenAI format for tokenizer."""
|
||||
openai_tools = []
|
||||
@@ -218,24 +219,24 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
||||
"name": getattr(func, "name", ""),
|
||||
"description": getattr(func, "description", ""),
|
||||
"parameters": getattr(func, "parameters", {}),
|
||||
}
|
||||
},
|
||||
}
|
||||
openai_tools.append(tool_dict)
|
||||
return openai_tools
|
||||
|
||||
|
||||
def _format_messages_to_text(self, messages: list, tools: list) -> str:
|
||||
"""
|
||||
Format messages to text using tokenizer's apply_chat_template if available.
|
||||
|
||||
|
||||
Falls back to manual formatting if tokenizer is not available.
|
||||
"""
|
||||
tokenizer = self._get_tokenizer()
|
||||
|
||||
|
||||
if tokenizer is not None:
|
||||
# Use tokenizer's apply_chat_template for proper formatting
|
||||
openai_messages = self._convert_messages_to_openai_format(messages)
|
||||
openai_tools = self._convert_tools_to_openai_format(tools) if tools else None
|
||||
|
||||
|
||||
try:
|
||||
formatted = tokenizer.apply_chat_template(
|
||||
openai_messages,
|
||||
@@ -247,30 +248,30 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
||||
return formatted
|
||||
except Exception as e:
|
||||
logger.warning(f"apply_chat_template failed: {e}, falling back to manual formatting")
|
||||
|
||||
|
||||
# Fallback to manual formatting
|
||||
return self._format_messages_to_text_manual(messages, tools)
|
||||
|
||||
|
||||
def _format_messages_to_text_manual(self, messages: list, tools: list) -> str:
|
||||
"""Manual fallback formatting for when tokenizer is not available."""
|
||||
formatted_parts = []
|
||||
tools_section = self._format_tools_for_prompt(tools)
|
||||
|
||||
|
||||
for msg in messages:
|
||||
# Handle both dict and Pydantic Message objects
|
||||
if hasattr(msg, 'role'):
|
||||
if hasattr(msg, "role"):
|
||||
role = msg.role
|
||||
content = msg.content if hasattr(msg, 'content') else ""
|
||||
content = msg.content if hasattr(msg, "content") else ""
|
||||
if isinstance(content, list):
|
||||
content = " ".join([c.text if hasattr(c, 'text') else str(c) for c in content])
|
||||
content = " ".join([c.text if hasattr(c, "text") else str(c) for c in content])
|
||||
elif content is None:
|
||||
content = ""
|
||||
tool_calls = getattr(msg, 'tool_calls', None)
|
||||
tool_calls = getattr(msg, "tool_calls", None)
|
||||
else:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
tool_calls = msg.get("tool_calls", None)
|
||||
|
||||
|
||||
if role == "system":
|
||||
system_content = content + tools_section if tools_section else content
|
||||
formatted_parts.append(f"<|im_start|>system\n{system_content}<|im_end|>")
|
||||
@@ -281,62 +282,55 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
||||
if tool_calls:
|
||||
tc_parts = []
|
||||
for tc in tool_calls:
|
||||
if hasattr(tc, 'function'):
|
||||
if hasattr(tc, "function"):
|
||||
tc_name = tc.function.name
|
||||
tc_args = tc.function.arguments
|
||||
else:
|
||||
tc_name = tc.get("function", {}).get("name", "")
|
||||
tc_args = tc.get("function", {}).get("arguments", "{}")
|
||||
|
||||
|
||||
if isinstance(tc_args, str):
|
||||
try:
|
||||
tc_args = json.loads(tc_args)
|
||||
except:
|
||||
pass
|
||||
|
||||
tc_parts.append(
|
||||
f"<tool_call>\n"
|
||||
f'{{"name": "{tc_name}", "arguments": {json.dumps(tc_args)}}}\n'
|
||||
f"</tool_call>"
|
||||
)
|
||||
|
||||
|
||||
tc_parts.append(f'<tool_call>\n{{"name": "{tc_name}", "arguments": {json.dumps(tc_args)}}}\n</tool_call>')
|
||||
|
||||
assistant_content = content + "\n" + "\n".join(tc_parts) if content else "\n".join(tc_parts)
|
||||
formatted_parts.append(f"<|im_start|>assistant\n{assistant_content}<|im_end|>")
|
||||
elif content:
|
||||
formatted_parts.append(f"<|im_start|>assistant\n{content}<|im_end|>")
|
||||
elif role == "tool":
|
||||
formatted_parts.append(
|
||||
f"<|im_start|>user\n"
|
||||
f"<tool_response>\n{content}\n</tool_response><|im_end|>"
|
||||
)
|
||||
|
||||
formatted_parts.append(f"<|im_start|>user\n<tool_response>\n{content}\n</tool_response><|im_end|>")
|
||||
|
||||
formatted_parts.append("<|im_start|>assistant\n")
|
||||
return "\n".join(formatted_parts)
|
||||
|
||||
|
||||
def _parse_tool_calls(self, text: str) -> list[ToolCall]:
|
||||
"""
|
||||
Parse tool calls from response text.
|
||||
|
||||
|
||||
Looks for patterns like:
|
||||
<tool_call>
|
||||
{"name": "tool_name", "arguments": {...}}
|
||||
</tool_call>
|
||||
"""
|
||||
tool_calls = []
|
||||
|
||||
|
||||
# Find all tool_call blocks
|
||||
pattern = r'<tool_call>\s*(\{.*?\})\s*</tool_call>'
|
||||
pattern = r"<tool_call>\s*(\{.*?\})\s*</tool_call>"
|
||||
matches = re.findall(pattern, text, re.DOTALL)
|
||||
|
||||
|
||||
for match in matches:
|
||||
try:
|
||||
tc_data = json.loads(match)
|
||||
name = tc_data.get("name", "")
|
||||
arguments = tc_data.get("arguments", {})
|
||||
|
||||
|
||||
if isinstance(arguments, dict):
|
||||
arguments = json.dumps(arguments)
|
||||
|
||||
|
||||
tool_call = ToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
@@ -349,17 +343,17 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Failed to parse tool call JSON: {e}")
|
||||
continue
|
||||
|
||||
|
||||
return tool_calls
|
||||
|
||||
|
||||
def _extract_content_without_tool_calls(self, text: str) -> str:
|
||||
"""Extract content from response, removing tool_call blocks."""
|
||||
# Remove tool_call blocks
|
||||
cleaned = re.sub(r'<tool_call>.*?</tool_call>', '', text, flags=re.DOTALL)
|
||||
cleaned = re.sub(r"<tool_call>.*?</tool_call>", "", text, flags=re.DOTALL)
|
||||
# Clean up whitespace
|
||||
cleaned = cleaned.strip()
|
||||
return cleaned
|
||||
|
||||
|
||||
async def invoke_llm(
|
||||
self,
|
||||
request_data: dict,
|
||||
@@ -372,7 +366,7 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
||||
) -> AsyncGenerator[LettaMessage | None, None]:
|
||||
"""
|
||||
Execute LLM request using SGLang native endpoint.
|
||||
|
||||
|
||||
This method:
|
||||
1. Formats messages and tools to text using chat template
|
||||
2. Calls SGLang native /generate endpoint
|
||||
@@ -381,20 +375,20 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
||||
5. Converts response to standard format
|
||||
"""
|
||||
self.request_data = request_data
|
||||
|
||||
|
||||
# Get sampling params from request_data
|
||||
sampling_params = {
|
||||
"temperature": request_data.get("temperature", 0.7),
|
||||
"max_new_tokens": request_data.get("max_tokens", 4096),
|
||||
"top_p": request_data.get("top_p", 0.9),
|
||||
}
|
||||
|
||||
|
||||
# Format messages to text (includes tools in prompt)
|
||||
text_input = self._format_messages_to_text(messages, tools)
|
||||
|
||||
|
||||
# Call SGLang native endpoint
|
||||
client = self._get_sglang_client()
|
||||
|
||||
|
||||
try:
|
||||
response = await client.generate(
|
||||
text=text_input,
|
||||
@@ -404,31 +398,31 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
||||
except Exception as e:
|
||||
logger.error(f"SGLang native endpoint error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
|
||||
|
||||
|
||||
# Store native response data
|
||||
self.response_data = response
|
||||
|
||||
|
||||
# Extract SGLang native data
|
||||
self.output_ids = response.get("output_ids")
|
||||
# output_token_logprobs is inside meta_info
|
||||
meta_info = response.get("meta_info", {})
|
||||
self.output_token_logprobs = meta_info.get("output_token_logprobs")
|
||||
|
||||
|
||||
# Extract text response
|
||||
text_response = response.get("text", "")
|
||||
|
||||
|
||||
# Remove trailing end token if present
|
||||
if text_response.endswith("<|im_end|>"):
|
||||
text_response = text_response[:-10]
|
||||
|
||||
|
||||
# Parse tool calls from response
|
||||
parsed_tool_calls = self._parse_tool_calls(text_response)
|
||||
|
||||
|
||||
# Extract content (text without tool_call blocks)
|
||||
content_text = self._extract_content_without_tool_calls(text_response)
|
||||
|
||||
|
||||
# Determine finish reason
|
||||
meta_info = response.get("meta_info", {})
|
||||
finish_reason_info = meta_info.get("finish_reason", {})
|
||||
@@ -436,11 +430,11 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
||||
finish_reason = finish_reason_info.get("type", "stop")
|
||||
else:
|
||||
finish_reason = "stop"
|
||||
|
||||
|
||||
# If we have tool calls, set finish_reason to tool_calls
|
||||
if parsed_tool_calls:
|
||||
finish_reason = "tool_calls"
|
||||
|
||||
|
||||
# Convert to standard ChatCompletionResponse format for compatibility
|
||||
# Build logprobs in OpenAI format from SGLang format
|
||||
logprobs_content = None
|
||||
@@ -458,13 +452,13 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
||||
top_logprobs=[],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
choice_logprobs = ChoiceLogprobs(content=logprobs_content) if logprobs_content else None
|
||||
|
||||
|
||||
# Build chat completion response
|
||||
prompt_tokens = meta_info.get("prompt_tokens", 0)
|
||||
completion_tokens = len(self.output_ids) if self.output_ids else 0
|
||||
|
||||
|
||||
self.chat_completions_response = ChatCompletionResponse(
|
||||
id=meta_info.get("id", "sglang-native"),
|
||||
created=int(time.time()),
|
||||
@@ -486,36 +480,36 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Extract content
|
||||
if content_text:
|
||||
self.content = [TextContent(text=content_text)]
|
||||
else:
|
||||
self.content = None
|
||||
|
||||
|
||||
# No reasoning content from native endpoint
|
||||
self.reasoning_content = None
|
||||
|
||||
|
||||
# Set tool calls
|
||||
self.tool_calls = parsed_tool_calls
|
||||
self.tool_call = parsed_tool_calls[0] if parsed_tool_calls else None
|
||||
|
||||
|
||||
# Set logprobs
|
||||
self.logprobs = choice_logprobs
|
||||
|
||||
|
||||
# Extract usage statistics
|
||||
self.usage.step_count = 1
|
||||
self.usage.completion_tokens = completion_tokens
|
||||
self.usage.prompt_tokens = prompt_tokens
|
||||
self.usage.total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
|
||||
self.log_provider_trace(step_id=step_id, actor=actor)
|
||||
|
||||
|
||||
logger.info(
|
||||
f"SGLang native response: {len(self.output_ids or [])} tokens, "
|
||||
f"{len(self.output_token_logprobs or [])} logprobs, "
|
||||
f"{len(parsed_tool_calls)} tool calls"
|
||||
)
|
||||
|
||||
|
||||
yield None
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user