fix: lmstudio support for qwen and llama
Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local> Co-authored-by: Charles Packer <packercharles@gmail.com>
This commit is contained in:
@@ -101,6 +101,8 @@ def requires_auto_tool_choice(llm_config: LLMConfig) -> bool:
|
||||
return True
|
||||
if llm_config.handle and "vllm" in llm_config.handle:
|
||||
return True
|
||||
if llm_config.compatibility_type == "mlx":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -81,6 +81,7 @@ class LLMConfig(BaseModel):
|
||||
None, # Can also default to 0.0?
|
||||
description="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. From OpenAI: Number between -2.0 and 2.0.",
|
||||
)
|
||||
compatibility_type: Optional[Literal["gguf", "mlx"]] = Field(None, description="The framework compatibility type for the model.")
|
||||
|
||||
# FIXME hack to silence pydantic protected namespace warning
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@@ -45,6 +45,12 @@ class LMStudioOpenAIProvider(OpenAIProvider):
|
||||
continue
|
||||
model_name, context_window_size = check
|
||||
|
||||
if "compatibility_type" in model:
|
||||
compatibility_type = model["compatibility_type"]
|
||||
else:
|
||||
warnings.warn(f"LMStudio OpenAI model missing 'compatibility_type' field: {model}")
|
||||
continue
|
||||
|
||||
configs.append(
|
||||
LLMConfig(
|
||||
model=model_name,
|
||||
@@ -52,6 +58,7 @@ class LMStudioOpenAIProvider(OpenAIProvider):
|
||||
model_endpoint=self.base_url,
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name),
|
||||
compatibility_type=compatibility_type,
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Literal, Optional, Set
|
||||
|
||||
@@ -351,21 +352,40 @@ def initialize_message_sequence(
|
||||
first_user_message = get_login_event(agent_state.timezone) # event letting Letta know the user just logged in
|
||||
|
||||
if include_initial_boot_message:
|
||||
llm_config = agent_state.llm_config
|
||||
uuid_str = str(uuid.uuid4())
|
||||
|
||||
# Some LMStudio models (e.g. ministral) require the tool call ID to be 9 alphanumeric characters
|
||||
tool_call_id = uuid_str[:9] if llm_config.provider_name == "lmstudio_openai" else uuid_str
|
||||
|
||||
if agent_state.agent_type == AgentType.sleeptime_agent:
|
||||
initial_boot_messages = []
|
||||
elif agent_state.llm_config.model is not None and "gpt-3.5" in agent_state.llm_config.model:
|
||||
initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35", agent_state.timezone)
|
||||
elif llm_config.model is not None and "gpt-3.5" in llm_config.model:
|
||||
initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35", agent_state.timezone, tool_call_id)
|
||||
else:
|
||||
initial_boot_messages = get_initial_boot_messages("startup_with_send_message", agent_state.timezone)
|
||||
messages = (
|
||||
[
|
||||
{"role": "system", "content": full_system_message},
|
||||
]
|
||||
+ initial_boot_messages
|
||||
+ [
|
||||
{"role": "user", "content": first_user_message},
|
||||
]
|
||||
)
|
||||
initial_boot_messages = get_initial_boot_messages("startup_with_send_message", agent_state.timezone, tool_call_id)
|
||||
|
||||
# Some LMStudio models (e.g. meta-llama-3.1) require the user message before any tool calls
|
||||
if llm_config.provider_name == "lmstudio_openai":
|
||||
messages = (
|
||||
[
|
||||
{"role": "system", "content": full_system_message},
|
||||
]
|
||||
+ [
|
||||
{"role": "user", "content": first_user_message},
|
||||
]
|
||||
+ initial_boot_messages
|
||||
)
|
||||
else:
|
||||
messages = (
|
||||
[
|
||||
{"role": "system", "content": full_system_message},
|
||||
]
|
||||
+ initial_boot_messages
|
||||
+ [
|
||||
{"role": "user", "content": first_user_message},
|
||||
]
|
||||
)
|
||||
|
||||
else:
|
||||
messages = [
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
@@ -13,7 +12,7 @@ from .helpers.datetime_helpers import get_local_time
|
||||
from .helpers.json_helpers import json_dumps
|
||||
|
||||
|
||||
def get_initial_boot_messages(version, timezone):
|
||||
def get_initial_boot_messages(version, timezone, tool_call_id):
|
||||
if version == "startup":
|
||||
initial_boot_message = INITIAL_BOOT_MESSAGE
|
||||
messages = [
|
||||
@@ -21,7 +20,6 @@ def get_initial_boot_messages(version, timezone):
|
||||
]
|
||||
|
||||
elif version == "startup_with_send_message":
|
||||
tool_call_id = str(uuid.uuid4())
|
||||
messages = [
|
||||
# first message includes both inner monologue and function call to send_message
|
||||
{
|
||||
@@ -53,7 +51,6 @@ def get_initial_boot_messages(version, timezone):
|
||||
]
|
||||
|
||||
elif version == "startup_with_send_message_gpt35":
|
||||
tool_call_id = str(uuid.uuid4())
|
||||
messages = [
|
||||
# first message includes both inner monologue and function call to send_message
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user