fix: parallel tool calling OpenAI (#2738)

This commit is contained in:
cthomas
2025-06-10 14:27:01 -07:00
committed by GitHub
parent b332ebfa85
commit 5ecd8a706c
6 changed files with 19 additions and 16 deletions

View File

@@ -31,6 +31,9 @@ LETTA_TOOL_MODULE_NAMES = [
LETTA_FILES_TOOL_MODULE_NAME,
]
DEFAULT_ORG_ID = "org-00000000-0000-4000-8000-000000000000"
DEFAULT_ORG_NAME = "default_org"
# String in the error message for when the context window is too large
# Example full message:

View File

@@ -217,12 +217,14 @@ class OpenAIClient(LLMClientBase):
messages=fill_image_content_in_messages(openai_message_list, messages),
tools=[OpenAITool(type="function", function=f) for f in tools] if tools else None,
tool_choice=tool_choice,
parallel_tool_calls=False if tools else None, # Forcibly disable parallel tool calling
user=str(),
max_completion_tokens=llm_config.max_tokens,
# NOTE: the reasoners that don't support temperature require 1.0, not None
temperature=llm_config.temperature if supports_temperature_param(model) else 1.0,
)
if tools and supports_parallel_tool_calling(model):
data.parallel_tool_calls = False
# always set user id for openai requests
if self.actor:
data.user = self.actor.id

View File

@@ -3,8 +3,8 @@ from typing import Optional
from pydantic import Field
from letta.constants import DEFAULT_ORG_ID
from letta.schemas.letta_base import LettaBase
from letta.services.organization_manager import OrganizationManager
class UserBase(LettaBase):
@@ -22,7 +22,7 @@ class User(UserBase):
"""
id: str = UserBase.generate_id_field()
organization_id: Optional[str] = Field(OrganizationManager.DEFAULT_ORG_ID, description="The organization id of the user")
organization_id: Optional[str] = Field(DEFAULT_ORG_ID, description="The organization id of the user")
name: str = Field(..., description="The name of the user.")
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The creation date of the user.")
updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The update date of the user.")

View File

@@ -1,5 +1,6 @@
from typing import List, Optional
from letta.constants import DEFAULT_ORG_ID, DEFAULT_ORG_NAME
from letta.orm.errors import NoResultFound
from letta.orm.organization import Organization as OrganizationModel
from letta.otel.tracing import trace_method
@@ -12,14 +13,11 @@ from letta.utils import enforce_types
class OrganizationManager:
"""Manager class to handle business logic related to Organizations."""
DEFAULT_ORG_ID = "org-00000000-0000-4000-8000-000000000000"
DEFAULT_ORG_NAME = "default_org"
@enforce_types
@trace_method
async def get_default_organization_async(self) -> PydanticOrganization:
"""Fetch the default organization."""
return await self.get_organization_by_id_async(self.DEFAULT_ORG_ID)
return await self.get_organization_by_id_async(DEFAULT_ORG_ID)
@enforce_types
@trace_method
@@ -72,14 +70,14 @@ class OrganizationManager:
@trace_method
def create_default_organization(self) -> PydanticOrganization:
"""Create the default organization."""
pydantic_org = PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID)
pydantic_org = PydanticOrganization(name=DEFAULT_ORG_NAME, id=DEFAULT_ORG_ID)
return self.create_organization(pydantic_org)
@enforce_types
@trace_method
async def create_default_organization_async(self) -> PydanticOrganization:
"""Create the default organization."""
return await self.create_organization_async(PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID))
return await self.create_organization_async(PydanticOrganization(name=DEFAULT_ORG_NAME, id=DEFAULT_ORG_ID))
@enforce_types
@trace_method

View File

@@ -3,6 +3,7 @@ from typing import List, Optional
from async_lru import alru_cache
from sqlalchemy import select, text
from letta.constants import DEFAULT_ORG_ID
from letta.orm.errors import NoResultFound
from letta.orm.organization import Organization as OrganizationModel
from letta.orm.user import User as UserModel
@@ -10,7 +11,6 @@ from letta.otel.tracing import trace_method
from letta.schemas.user import User as PydanticUser
from letta.schemas.user import UserUpdate
from letta.server.db import db_registry
from letta.services.organization_manager import OrganizationManager
from letta.utils import enforce_types
@@ -43,7 +43,7 @@ class UserManager:
@enforce_types
@trace_method
def create_default_user(self, org_id: str = OrganizationManager.DEFAULT_ORG_ID) -> PydanticUser:
def create_default_user(self, org_id: str = DEFAULT_ORG_ID) -> PydanticUser:
"""Create the default user."""
with db_registry.session() as session:
# Make sure the org id exists
@@ -65,7 +65,7 @@ class UserManager:
@enforce_types
@trace_method
async def create_default_actor_async(self, org_id: str = OrganizationManager.DEFAULT_ORG_ID) -> PydanticUser:
async def create_default_actor_async(self, org_id: str = DEFAULT_ORG_ID) -> PydanticUser:
"""Create the default user."""
async with db_registry.async_session() as session:
# Make sure the org id exists
@@ -218,7 +218,7 @@ class UserManager:
try:
return await self.get_actor_by_id_async(self.DEFAULT_USER_ID)
except NoResultFound:
return await self.create_default_actor_async(org_id=OrganizationManager.DEFAULT_ORG_ID)
return await self.create_default_actor_async(org_id=DEFAULT_ORG_ID)
@enforce_types
@trace_method
@@ -229,7 +229,7 @@ class UserManager:
try:
return await self._get_actor_cached(target_id)
except NoResultFound:
user = await self.create_default_actor_async(org_id=OrganizationManager.DEFAULT_ORG_ID)
user = await self.create_default_actor_async(org_id=DEFAULT_ORG_ID)
return user
@enforce_types

View File

@@ -29,6 +29,7 @@ from letta.constants import (
BASE_VOICE_SLEEPTIME_CHAT_TOOLS,
BASE_VOICE_SLEEPTIME_TOOLS,
BUILTIN_TOOLS,
DEFAULT_ORG_ID,
FILES_TOOLS,
LETTA_TOOL_EXECUTION_DIR,
LETTA_TOOL_SET,
@@ -81,7 +82,6 @@ from letta.schemas.user import UserUpdate
from letta.server.db import db_registry
from letta.server.server import SyncServer
from letta.services.block_manager import BlockManager
from letta.services.organization_manager import OrganizationManager
from letta.settings import tool_settings
from tests.helpers.utils import comprehensive_agent_checks, validate_context_window_overview
from tests.utils import random_string
@@ -2724,7 +2724,7 @@ async def test_update_user(server: SyncServer, event_loop):
# Adjust name
user = await server.user_manager.update_actor_async(UserUpdate(id=user.id, name=user_name_b))
assert user.name == user_name_b
assert user.organization_id == OrganizationManager.DEFAULT_ORG_ID
assert user.organization_id == DEFAULT_ORG_ID
# Adjust org id
user = await server.user_manager.update_actor_async(UserUpdate(id=user.id, organization_id=test_org.id))