fix: parallel tool calling OpenAI (#2738)
This commit is contained in:
@@ -31,6 +31,9 @@ LETTA_TOOL_MODULE_NAMES = [
|
||||
LETTA_FILES_TOOL_MODULE_NAME,
|
||||
]
|
||||
|
||||
DEFAULT_ORG_ID = "org-00000000-0000-4000-8000-000000000000"
|
||||
DEFAULT_ORG_NAME = "default_org"
|
||||
|
||||
|
||||
# String in the error message for when the context window is too large
|
||||
# Example full message:
|
||||
|
||||
@@ -217,12 +217,14 @@ class OpenAIClient(LLMClientBase):
|
||||
messages=fill_image_content_in_messages(openai_message_list, messages),
|
||||
tools=[OpenAITool(type="function", function=f) for f in tools] if tools else None,
|
||||
tool_choice=tool_choice,
|
||||
parallel_tool_calls=False if tools else None, # Forcibly disable parallel tool calling
|
||||
user=str(),
|
||||
max_completion_tokens=llm_config.max_tokens,
|
||||
# NOTE: the reasoners that don't support temperature require 1.0, not None
|
||||
temperature=llm_config.temperature if supports_temperature_param(model) else 1.0,
|
||||
)
|
||||
if tools and supports_parallel_tool_calling(model):
|
||||
data.parallel_tool_calls = False
|
||||
|
||||
# always set user id for openai requests
|
||||
if self.actor:
|
||||
data.user = self.actor.id
|
||||
|
||||
@@ -3,8 +3,8 @@ from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.constants import DEFAULT_ORG_ID
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
|
||||
|
||||
class UserBase(LettaBase):
|
||||
@@ -22,7 +22,7 @@ class User(UserBase):
|
||||
"""
|
||||
|
||||
id: str = UserBase.generate_id_field()
|
||||
organization_id: Optional[str] = Field(OrganizationManager.DEFAULT_ORG_ID, description="The organization id of the user")
|
||||
organization_id: Optional[str] = Field(DEFAULT_ORG_ID, description="The organization id of the user")
|
||||
name: str = Field(..., description="The name of the user.")
|
||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The creation date of the user.")
|
||||
updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The update date of the user.")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.constants import DEFAULT_ORG_ID, DEFAULT_ORG_NAME
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.organization import Organization as OrganizationModel
|
||||
from letta.otel.tracing import trace_method
|
||||
@@ -12,14 +13,11 @@ from letta.utils import enforce_types
|
||||
class OrganizationManager:
|
||||
"""Manager class to handle business logic related to Organizations."""
|
||||
|
||||
DEFAULT_ORG_ID = "org-00000000-0000-4000-8000-000000000000"
|
||||
DEFAULT_ORG_NAME = "default_org"
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_default_organization_async(self) -> PydanticOrganization:
|
||||
"""Fetch the default organization."""
|
||||
return await self.get_organization_by_id_async(self.DEFAULT_ORG_ID)
|
||||
return await self.get_organization_by_id_async(DEFAULT_ORG_ID)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -72,14 +70,14 @@ class OrganizationManager:
|
||||
@trace_method
|
||||
def create_default_organization(self) -> PydanticOrganization:
|
||||
"""Create the default organization."""
|
||||
pydantic_org = PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID)
|
||||
pydantic_org = PydanticOrganization(name=DEFAULT_ORG_NAME, id=DEFAULT_ORG_ID)
|
||||
return self.create_organization(pydantic_org)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_default_organization_async(self) -> PydanticOrganization:
|
||||
"""Create the default organization."""
|
||||
return await self.create_organization_async(PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID))
|
||||
return await self.create_organization_async(PydanticOrganization(name=DEFAULT_ORG_NAME, id=DEFAULT_ORG_ID))
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import List, Optional
|
||||
from async_lru import alru_cache
|
||||
from sqlalchemy import select, text
|
||||
|
||||
from letta.constants import DEFAULT_ORG_ID
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.organization import Organization as OrganizationModel
|
||||
from letta.orm.user import User as UserModel
|
||||
@@ -10,7 +11,6 @@ from letta.otel.tracing import trace_method
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.schemas.user import UserUpdate
|
||||
from letta.server.db import db_registry
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ class UserManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def create_default_user(self, org_id: str = OrganizationManager.DEFAULT_ORG_ID) -> PydanticUser:
|
||||
def create_default_user(self, org_id: str = DEFAULT_ORG_ID) -> PydanticUser:
|
||||
"""Create the default user."""
|
||||
with db_registry.session() as session:
|
||||
# Make sure the org id exists
|
||||
@@ -65,7 +65,7 @@ class UserManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_default_actor_async(self, org_id: str = OrganizationManager.DEFAULT_ORG_ID) -> PydanticUser:
|
||||
async def create_default_actor_async(self, org_id: str = DEFAULT_ORG_ID) -> PydanticUser:
|
||||
"""Create the default user."""
|
||||
async with db_registry.async_session() as session:
|
||||
# Make sure the org id exists
|
||||
@@ -218,7 +218,7 @@ class UserManager:
|
||||
try:
|
||||
return await self.get_actor_by_id_async(self.DEFAULT_USER_ID)
|
||||
except NoResultFound:
|
||||
return await self.create_default_actor_async(org_id=OrganizationManager.DEFAULT_ORG_ID)
|
||||
return await self.create_default_actor_async(org_id=DEFAULT_ORG_ID)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -229,7 +229,7 @@ class UserManager:
|
||||
try:
|
||||
return await self._get_actor_cached(target_id)
|
||||
except NoResultFound:
|
||||
user = await self.create_default_actor_async(org_id=OrganizationManager.DEFAULT_ORG_ID)
|
||||
user = await self.create_default_actor_async(org_id=DEFAULT_ORG_ID)
|
||||
return user
|
||||
|
||||
@enforce_types
|
||||
|
||||
@@ -29,6 +29,7 @@ from letta.constants import (
|
||||
BASE_VOICE_SLEEPTIME_CHAT_TOOLS,
|
||||
BASE_VOICE_SLEEPTIME_TOOLS,
|
||||
BUILTIN_TOOLS,
|
||||
DEFAULT_ORG_ID,
|
||||
FILES_TOOLS,
|
||||
LETTA_TOOL_EXECUTION_DIR,
|
||||
LETTA_TOOL_SET,
|
||||
@@ -81,7 +82,6 @@ from letta.schemas.user import UserUpdate
|
||||
from letta.server.db import db_registry
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.settings import tool_settings
|
||||
from tests.helpers.utils import comprehensive_agent_checks, validate_context_window_overview
|
||||
from tests.utils import random_string
|
||||
@@ -2724,7 +2724,7 @@ async def test_update_user(server: SyncServer, event_loop):
|
||||
# Adjust name
|
||||
user = await server.user_manager.update_actor_async(UserUpdate(id=user.id, name=user_name_b))
|
||||
assert user.name == user_name_b
|
||||
assert user.organization_id == OrganizationManager.DEFAULT_ORG_ID
|
||||
assert user.organization_id == DEFAULT_ORG_ID
|
||||
|
||||
# Adjust org id
|
||||
user = await server.user_manager.update_actor_async(UserUpdate(id=user.id, organization_id=test_org.id))
|
||||
|
||||
Reference in New Issue
Block a user