diff --git a/letta/cli/cli.py b/letta/cli/cli.py index c6e435a8..3852e2b4 100644 --- a/letta/cli/cli.py +++ b/letta/cli/cli.py @@ -218,9 +218,7 @@ def run( ) # create agent - tools = [ - server.tool_manager.get_tool_by_name_and_user_id(tool_name=tool_name, user_id=client.user_id) for tool_name in agent_state.tools - ] + tools = [server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=client.user) for tool_name in agent_state.tools] letta_agent = Agent(agent_state=agent_state, interface=interface(), tools=tools) else: # create new agent @@ -300,7 +298,7 @@ def run( ) assert isinstance(agent_state.memory, Memory), f"Expected Memory, got {type(agent_state.memory)}" typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t for t in agent_state.tools])}", fg=typer.colors.WHITE) - tools = [server.tool_manager.get_tool_by_name_and_user_id(tool_name, user_id=client.user_id) for tool_name in agent_state.tools] + tools = [server.tool_manager.get_tool_by_name(tool_name, actor=client.user) for tool_name in agent_state.tools] letta_agent = Agent( interface=interface(), diff --git a/letta/client/client.py b/letta/client/client.py index fd2c49eb..6d02d74d 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -1546,6 +1546,9 @@ class LocalClient(AbstractClient): # get default user self.user_id = self.server.user_manager.DEFAULT_USER_ID + self.user = self.server.get_user_or_default(self.user_id) + self.organization = self.server.get_organization_or_default(self.org_id) + # agents def list_agents(self) -> List[AgentState]: self.interface.clear() @@ -1648,7 +1651,7 @@ class LocalClient(AbstractClient): llm_config=llm_config if llm_config else self._default_llm_config, embedding_config=embedding_config if embedding_config else self._default_embedding_config, ), - user_id=self.user_id, + actor=self.user, ) return agent_state @@ -1720,7 +1723,7 @@ class LocalClient(AbstractClient): message_ids=message_ids, memory=memory, ), - user_id=self.user_id, + actor=self.user, ) return agent_state @@ -2198,24 +2201,22 @@ class LocalClient(AbstractClient): def load_langchain_tool(self, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool: tool_create = ToolCreate.from_langchain( langchain_tool=langchain_tool, - user_id=self.user_id, organization_id=self.org_id, additional_imports_module_attr_map=additional_imports_module_attr_map, ) - return self.server.tool_manager.create_or_update_tool(tool_create) + return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user) def load_crewai_tool(self, crewai_tool: "CrewAIBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool: tool_create = ToolCreate.from_crewai( crewai_tool=crewai_tool, additional_imports_module_attr_map=additional_imports_module_attr_map, - user_id=self.user_id, organization_id=self.org_id, ) - return self.server.tool_manager.create_or_update_tool(tool_create) + return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user) def load_composio_tool(self, action: "ActionType") -> Tool: - tool_create = ToolCreate.from_composio(action=action, user_id=self.user_id, organization_id=self.org_id) - return self.server.tool_manager.create_or_update_tool(tool_create) + tool_create = ToolCreate.from_composio(action=action, organization_id=self.org_id) + return self.server.tool_manager.create_or_update_tool(tool_create, actor=self.user) # TODO: Use the above function `add_tool` here as there is duplicate logic def create_tool( @@ -2250,14 +2251,13 @@ class LocalClient(AbstractClient): # call server function return self.server.tool_manager.create_or_update_tool( ToolCreate( - user_id=self.user_id, - organization_id=self.org_id, source_type=source_type, source_code=source_code, name=name, tags=tags, terminal=terminal, ), + actor=self.user, ) def update_tool( @@ -2289,7 +2289,7 @@ class LocalClient(AbstractClient): # Filter out any None values from the dictionary update_data = {key: value for key, value in update_data.items() if value is not None} - return self.server.tool_manager.update_tool_by_id(id, ToolUpdate(**update_data)) + return self.server.tool_manager.update_tool_by_id(tool_id=id, tool_update=ToolUpdate(**update_data), actor=self.user) def list_tools(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Tool]: """ @@ -2298,7 +2298,7 @@ class LocalClient(AbstractClient): Returns: tools (List[Tool]): List of tools """ - return self.server.tool_manager.list_tools_for_org(cursor=cursor, limit=limit, organization_id=self.org_id) + return self.server.tool_manager.list_tools(cursor=cursor, limit=limit, actor=self.user) def get_tool(self, id: str) -> Optional[Tool]: """ @@ -2310,7 +2310,7 @@ class LocalClient(AbstractClient): Returns: tool (Tool): Tool """ - return self.server.tool_manager.get_tool_by_id(id) + return self.server.tool_manager.get_tool_by_id(id, actor=self.user) def delete_tool(self, id: str): """ @@ -2319,7 +2319,7 @@ class LocalClient(AbstractClient): Args: id (str): ID of the tool """ - return self.server.tool_manager.delete_tool_by_id(id) + return self.server.tool_manager.delete_tool_by_id(id, user_id=self.user_id) def get_tool_id(self, name: str) -> Optional[str]: """ @@ -2331,7 +2331,7 @@ class LocalClient(AbstractClient): Returns: id (str): ID of the tool (`None` if not found) """ - tool = self.server.tool_manager.get_tool_by_name_and_org_id(tool_name=name, organization_id=self.org_id) + tool = self.server.tool_manager.get_tool_by_name(tool_name=name, actor=self.user) return tool.id def load_data(self, connector: DataConnector, source_name: str): diff --git a/letta/orm/base.py b/letta/orm/base.py index 61f7575d..d8a84751 100644 --- a/letta/orm/base.py +++ b/letta/orm/base.py @@ -1,9 +1,7 @@ from datetime import datetime from typing import Optional -from uuid import UUID -from sqlalchemy import UUID as SQLUUID -from sqlalchemy import Boolean, DateTime, func, text +from sqlalchemy import Boolean, DateTime, String, func, text from sqlalchemy.orm import ( DeclarativeBase, Mapped, @@ -25,6 +23,13 @@ class CommonSqlalchemyMetaMixins(Base): updated_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), server_default=func.now(), server_onupdate=func.now()) is_deleted: Mapped[bool] = mapped_column(Boolean, server_default=text("FALSE")) + def _set_created_and_updated_by_fields(self, actor_id: str) -> None: + """Populate created_by_id and last_updated_by_id based on actor.""" + if not self.created_by_id: + self.created_by_id = actor_id + # Always set the last_updated_by_id when updating + self.last_updated_by_id = actor_id + @declared_attr def _created_by_id(cls): return cls._user_by_id() @@ -38,7 +43,7 @@ class CommonSqlalchemyMetaMixins(Base): """a flexible non-constrained record of a user. This way users can get added, deleted etc without history freaking out """ - return mapped_column(SQLUUID(), nullable=True) + return mapped_column(String, nullable=True) @property def last_updated_by_id(self) -> Optional[str]: @@ -72,4 +77,4 @@ class CommonSqlalchemyMetaMixins(Base): return prefix, id_ = value.split("-", 1) assert prefix == "user", f"{prefix} is not a valid id prefix for a user id" - setattr(self, full_prop, UUID(id_)) + setattr(self, full_prop, id_) diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index a23d03da..59c90d94 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -1,5 +1,5 @@ -from typing import TYPE_CHECKING, List, Literal, Optional, Type, Union -from uuid import UUID, uuid4 +from typing import TYPE_CHECKING, List, Literal, Optional, Type +from uuid import uuid4 from humps import depascalize from sqlalchemy import Boolean, String, select @@ -88,7 +88,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): def read( cls, db_session: "Session", - identifier: Union[str, UUID], + identifier: Optional[str] = None, actor: Optional["User"] = None, access: Optional[List[Literal["read", "write", "admin"]]] = ["read"], **kwargs, @@ -105,19 +105,29 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): Raises: NoResultFound: if the object is not found """ - del kwargs # arity for more complex reads - identifier = cls.get_uid_from_identifier(identifier) - query = select(cls).where(cls._id == identifier) - # if actor: - # query = cls.apply_access_predicate(query, actor, access) + # Start the query + query = select(cls) + + # If an identifier is provided, add it to the query conditions + if identifier is not None: + identifier = cls.get_uid_from_identifier(identifier) + query = query.where(cls._id == identifier) + + if kwargs: + query = query.filter_by(**kwargs) + + if actor: + query = cls.apply_access_predicate(query, actor, access) + if hasattr(cls, "is_deleted"): query = query.where(cls.is_deleted == False) if found := db_session.execute(query).scalar(): return found raise NoResultFound(f"{cls.__name__} with id {identifier} not found") - def create(self, db_session: "Session") -> Type["SqlalchemyBase"]: - # self._infer_organization(db_session) + def create(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]: + if actor: + self._set_created_and_updated_by_fields(actor.id) with db_session as session: session.add(self) @@ -125,11 +135,17 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): session.refresh(self) return self - def delete(self, db_session: "Session") -> Type["SqlalchemyBase"]: + def delete(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]: + if actor: + self._set_created_and_updated_by_fields(actor.id) + self.is_deleted = True return self.update(db_session) - def update(self, db_session: "Session") -> Type["SqlalchemyBase"]: + def update(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]: + if actor: + self._set_created_and_updated_by_fields(actor.id) + with db_session as session: session.add(self) session.commit() @@ -137,39 +153,28 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): return self @classmethod - def read_or_create(cls, *, db_session: "Session", **kwargs) -> Type["SqlalchemyBase"]: - """get an instance by search criteria or create it if it doesn't exist""" - try: - return cls.read(db_session=db_session, identifier=kwargs.get("id", None)) - except NoResultFound: - clean_kwargs = {k: v for k, v in kwargs.items() if k in cls.__table__.columns} - return cls(**clean_kwargs).create(db_session=db_session) - - # TODO: Add back later when access predicates are actually important - # The idea behind this is that you can add a WHERE clause restricting the actions you can take, e.g. R/W - # @classmethod - # def apply_access_predicate( - # cls, - # query: "Select", - # actor: "User", - # access: List[Literal["read", "write", "admin"]], - # ) -> "Select": - # """applies a WHERE clause restricting results to the given actor and access level - # Args: - # query: The initial sqlalchemy select statement - # actor: The user acting on the query. **Note**: this is called 'actor' to identify the - # person or system acting. Users can act on users, making naming very sticky otherwise. - # access: - # what mode of access should the query restrict to? This will be used with granular permissions, - # but because of how it will impact every query we want to be explicitly calling access ahead of time. - # Returns: - # the sqlalchemy select statement restricted to the given access. - # """ - # del access # entrypoint for row-level permissions. Defaults to "same org as the actor, all permissions" at the moment - # org_uid = getattr(actor, "_organization_id", getattr(actor.organization, "_id", None)) - # if not org_uid: - # raise ValueError("object %s has no organization accessor", actor) - # return query.where(cls._organization_id == org_uid, cls.is_deleted == False) + def apply_access_predicate( + cls, + query: "Select", + actor: "User", + access: List[Literal["read", "write", "admin"]], + ) -> "Select": + """applies a WHERE clause restricting results to the given actor and access level + Args: + query: The initial sqlalchemy select statement + actor: The user acting on the query. **Note**: this is called 'actor' to identify the + person or system acting. Users can act on users, making naming very sticky otherwise. + access: + what mode of access should the query restrict to? This will be used with granular permissions, + but because of how it will impact every query we want to be explicitly calling access ahead of time. + Returns: + the sqlalchemy select statement restricted to the given access. + """ + del access # entrypoint for row-level permissions. Defaults to "same org as the actor, all permissions" at the moment + org_id = getattr(actor, "organization_id", None) + if not org_id: + raise ValueError(f"object {actor} has no organization accessor") + return query.where(cls._organization_id == cls.get_uid_from_identifier(org_id, indifferent=True), cls.is_deleted == False) @property def __pydantic_model__(self) -> Type["BaseModel"]: @@ -183,21 +188,3 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): """Deprecated accessor for to_pydantic""" logger.warning("to_record is deprecated, use to_pydantic instead.") return self.to_pydantic() - - def _infer_organization(self, db_session: "Session") -> None: - """🪄 MAGIC ALERT! 🪄 - Because so much of the original API is centered around user scopes, - this allows us to continue with that scope and then infer the org from the creating user. - - IF a created_by_id is set, we will use that to infer the organization and magic set it at create time! - If not do nothing to the object. Mutates in place. - """ - if self.created_by_id and hasattr(self, "_organization_id"): - try: - from letta.orm.user import User # to avoid circular import - - created_by = User.read(db_session, self.created_by_id) - except NoResultFound: - logger.warning(f"User {self.created_by_id} not found, unable to infer organization.") - return - self._organization_id = created_by._organization_id diff --git a/letta/orm/tool.py b/letta/orm/tool.py index 158ed235..b8f2ba1d 100644 --- a/letta/orm/tool.py +++ b/letta/orm/tool.py @@ -5,18 +5,15 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship # TODO everything in functions should live in this model from letta.orm.enums import ToolSourceType -from letta.orm.mixins import OrganizationMixin, UserMixin +from letta.orm.mixins import OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.tool import Tool as PydanticTool if TYPE_CHECKING: - pass - from letta.orm.organization import Organization - from letta.orm.user import User -class Tool(SqlalchemyBase, OrganizationMixin, UserMixin): +class Tool(SqlalchemyBase, OrganizationMixin): """Represents an available tool that the LLM can invoke. NOTE: polymorphic inheritance makes more sense here as a TODO. We want a superset of tools @@ -29,10 +26,7 @@ class Tool(SqlalchemyBase, OrganizationMixin, UserMixin): # Add unique constraint on (name, _organization_id) # An organization should not have multiple tools with the same name - __table_args__ = ( - UniqueConstraint("name", "_organization_id", name="uix_name_organization"), - UniqueConstraint("name", "_user_id", name="uix_name_user"), - ) + __table_args__ = (UniqueConstraint("name", "_organization_id", name="uix_name_organization"),) name: Mapped[str] = mapped_column(doc="The display name of the tool.") description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The description of the tool.") @@ -48,7 +42,4 @@ class Tool(SqlalchemyBase, OrganizationMixin, UserMixin): # This was an intentional decision by Sarah # relationships - # TODO: Possibly add in user in the future - # This will require some more thought and justification to add this in. - user: Mapped["User"] = relationship("User", back_populates="tools", lazy="selectin") organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin") diff --git a/letta/orm/user.py b/letta/orm/user.py index 31bd40f8..dfa0acc9 100644 --- a/letta/orm/user.py +++ b/letta/orm/user.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -8,7 +8,6 @@ from letta.schemas.user import User as PydanticUser if TYPE_CHECKING: from letta.orm.organization import Organization - from letta.orm.tool import Tool class User(SqlalchemyBase, OrganizationMixin): @@ -21,7 +20,6 @@ class User(SqlalchemyBase, OrganizationMixin): # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="users") - tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="user", cascade="all, delete-orphan") # TODO: Add this back later potentially # agents: Mapped[List["Agent"]] = relationship( diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index e31a3856..9f54f25b 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -11,7 +11,6 @@ from letta.functions.schema_generator import generate_schema_from_args_schema from letta.schemas.letta_base import LettaBase from letta.schemas.openai.chat_completions import ToolCall from letta.services.organization_manager import OrganizationManager -from letta.services.user_manager import UserManager class BaseTool(LettaBase): @@ -35,7 +34,6 @@ class Tool(BaseTool): description: Optional[str] = Field(None, description="The description of the tool.") source_type: Optional[str] = Field(None, description="The type of the source code.") module: Optional[str] = Field(None, description="The module of the function.") - user_id: str = Field(..., description="The unique identifier of the user associated with the tool.") organization_id: str = Field(..., description="The unique identifier of the organization associated with the tool.") name: str = Field(..., description="The name of the function.") tags: List[str] = Field(..., description="Metadata tags.") @@ -44,6 +42,10 @@ class Tool(BaseTool): source_code: str = Field(..., description="The source code of the function.") json_schema: Dict = Field(default_factory=dict, description="The JSON schema of the function.") + # metadata fields + created_by_id: str = Field(..., description="The id of the user that made this Tool.") + last_updated_by_id: str = Field(..., description="The id of the user that made this Tool.") + def to_dict(self): """ Convert tool into OpenAI representation. @@ -58,11 +60,6 @@ class Tool(BaseTool): class ToolCreate(LettaBase): - user_id: str = Field(UserManager.DEFAULT_USER_ID, description="The user that this tool belongs to. Defaults to the default user ID.") - organization_id: str = Field( - OrganizationManager.DEFAULT_ORG_ID, - description="The organization that this tool belongs to. Defaults to the default organization ID.", - ) name: Optional[str] = Field(None, description="The name of the function (auto-generated from source_code if not provided).") description: Optional[str] = Field(None, description="The description of the tool.") tags: List[str] = Field([], description="Metadata tags.") @@ -75,9 +72,7 @@ class ToolCreate(LettaBase): terminal: Optional[bool] = Field(None, description="Whether the tool is a terminal tool (allow requesting heartbeats).") @classmethod - def from_composio( - cls, action: "ActionType", user_id: str = UserManager.DEFAULT_USER_ID, organization_id: str = OrganizationManager.DEFAULT_ORG_ID - ) -> "ToolCreate": + def from_composio(cls, action: "ActionType", organization_id: str = OrganizationManager.DEFAULT_ORG_ID) -> "ToolCreate": """ Class method to create an instance of Letta-compatible Composio Tool. Check https://docs.composio.dev/introduction/intro/overview to look at options for from_composio @@ -106,8 +101,6 @@ class ToolCreate(LettaBase): json_schema = generate_schema_from_args_schema(composio_tool.args_schema, name=wrapper_func_name, description=description) return cls( - user_id=user_id, - organization_id=organization_id, name=wrapper_func_name, description=description, source_type=source_type, @@ -121,7 +114,6 @@ class ToolCreate(LettaBase): cls, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None, - user_id: str = UserManager.DEFAULT_USER_ID, organization_id: str = OrganizationManager.DEFAULT_ORG_ID, ) -> "ToolCreate": """ @@ -142,8 +134,6 @@ class ToolCreate(LettaBase): json_schema = generate_schema_from_args_schema(langchain_tool.args_schema, name=wrapper_func_name, description=description) return cls( - user_id=user_id, - organization_id=organization_id, name=wrapper_func_name, description=description, source_type=source_type, @@ -157,7 +147,6 @@ class ToolCreate(LettaBase): cls, crewai_tool: "CrewAIBaseTool", additional_imports_module_attr_map: dict[str, str] = None, - user_id: str = UserManager.DEFAULT_USER_ID, organization_id: str = OrganizationManager.DEFAULT_ORG_ID, ) -> "ToolCreate": """ @@ -176,8 +165,6 @@ class ToolCreate(LettaBase): json_schema = generate_schema_from_args_schema(crewai_tool.args_schema, name=wrapper_func_name, description=description) return cls( - user_id=user_id, - organization_id=organization_id, name=wrapper_func_name, description=description, source_type=source_type, diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 15ea2109..ca4afe06 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -84,7 +84,7 @@ def create_agent( blocks = agent.memory.get_blocks() agent.memory = BasicBlockMemory(blocks=blocks) - return server.create_agent(agent, user_id=actor.id) + return server.create_agent(agent, actor=actor) @router.patch("/{agent_id}", response_model=AgentState, operation_id="update_agent") @@ -96,9 +96,7 @@ def update_agent( ): """Update an exsiting agent""" actor = server.get_user_or_default(user_id=user_id) - - update_agent.id = agent_id - return server.update_agent(update_agent, user_id=actor.id) + return server.update_agent(update_agent, actor=actor) @router.get("/{agent_id}/tools", response_model=List[Tool], operation_id="get_tools_from_agent") diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 35f41a26..f12fcd18 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -26,11 +26,13 @@ def delete_tool( def get_tool( tool_id: str, server: SyncServer = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Get a tool by ID """ - tool = server.tool_manager.get_tool_by_id(tool_id=tool_id) + actor = server.get_user_or_default(user_id=user_id) + tool = server.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor) if tool is None: # return 404 error raise HTTPException(status_code=404, detail=f"Tool with id {tool_id} not found.") @@ -49,7 +51,7 @@ def get_tool_id( actor = server.get_user_or_default(user_id=user_id) try: - tool = server.tool_manager.get_tool_by_name_and_org_id(tool_name=tool_name, organization_id=actor.organization_id) + tool = server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor) return tool.id except NoResultFound: raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} and organization id {actor.organization_id} not found.") @@ -67,7 +69,7 @@ def list_tools( """ try: actor = server.get_user_or_default(user_id=user_id) - return server.tool_manager.list_tools_for_org(organization_id=actor.organization_id, cursor=cursor, limit=limit) + return server.tool_manager.list_tools(actor=actor, cursor=cursor, limit=limit) except Exception as e: # Log or print the full exception here for debugging print(f"Error occurred: {e}") @@ -85,13 +87,9 @@ def create_tool( """ # Derive user and org id from actor actor = server.get_user_or_default(user_id=user_id) - request.organization_id = actor.organization_id - request.user_id = actor.id # Send request to create the tool - return server.tool_manager.create_or_update_tool( - tool_create=request, - ) + return server.tool_manager.create_or_update_tool(tool_create=request, actor=actor) @router.patch("/{tool_id}", response_model=Tool, operation_id="update_tool") @@ -104,4 +102,5 @@ def update_tool( """ Update an existing tool """ - return server.tool_manager.update_tool_by_id(tool_id, request) + actor = server.get_user_or_default(user_id=user_id) + return server.tool_manager.update_tool_by_id(tool_id, actor.id, request) diff --git a/letta/server/server.py b/letta/server/server.py index 67a3b498..00e1dc37 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -37,6 +37,7 @@ from letta.log import get_logger from letta.memory import get_memory_functions from letta.metadata import Base, MetadataStore from letta.o1_agent import O1Agent +from letta.orm.errors import NoResultFound from letta.prompts import gpt_system from letta.providers import ( AnthropicProvider, @@ -73,6 +74,7 @@ from letta.schemas.memory import ( RecallMemorySummary, ) from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage +from letta.schemas.organization import Organization from letta.schemas.passage import Passage from letta.schemas.source import Source, SourceCreate, SourceUpdate from letta.schemas.tool import Tool, ToolCreate @@ -251,12 +253,12 @@ class SyncServer(Server): self.default_org = self.organization_manager.create_default_organization() self.default_user = self.user_manager.create_default_user() self.add_default_blocks(self.default_user.id) - self.tool_manager.add_default_tools(module_name="base", user_id=self.default_user.id, org_id=self.default_org.id) + self.tool_manager.add_default_tools(module_name="base", actor=self.default_user) # If there is a default org/user # This logic may have to change in the future if settings.load_default_external_tools: - self.add_default_external_tools(user_id=self.default_user.id, org_id=self.default_org.id) + self.add_default_external_tools(actor=self.default_user) # collect providers (always has Letta as a default) self._enabled_providers: List[Provider] = [LettaProvider()] @@ -345,10 +347,10 @@ class SyncServer(Server): } ) - def _load_agent(self, user_id: str, agent_id: str, interface: Union[AgentInterface, None] = None) -> Agent: + def _load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent: """Loads a saved agent into memory (if it doesn't exist, throw an error)""" - assert isinstance(user_id, str), user_id assert isinstance(agent_id, str), agent_id + user_id = actor.id # If an interface isn't specified, use the default if interface is None: @@ -365,7 +367,7 @@ class SyncServer(Server): logger.debug(f"Creating an agent object") tool_objs = [] for name in agent_state.tools: - tool_obj = self.tool_manager.get_tool_by_name_and_user_id(tool_name=name, user_id=user_id) + tool_obj = self.tool_manager.get_tool_by_name(tool_name=name, actor=actor) if not tool_obj: logger.exception(f"Tool {name} does not exist for user {user_id}") raise ValueError(f"Tool {name} does not exist for user {user_id}") @@ -396,13 +398,14 @@ class SyncServer(Server): if not agent_state: raise ValueError(f"Agent does not exist") user_id = agent_state.user_id + actor = self.user_manager.get_user_by_id(user_id) logger.debug(f"Checking for agent user_id={user_id} agent_id={agent_id}") # TODO: consider disabling loading cached agents due to potential concurrency issues letta_agent = self._get_agent(user_id=user_id, agent_id=agent_id) if not letta_agent: logger.debug(f"Agent not loaded, loading agent user_id={user_id} agent_id={agent_id}") - letta_agent = self._load_agent(user_id=user_id, agent_id=agent_id) + letta_agent = self._load_agent(agent_id=agent_id, actor=actor) return letta_agent def _step( @@ -759,11 +762,12 @@ class SyncServer(Server): def create_agent( self, request: CreateAgent, - user_id: str, + actor: User, # interface interface: Union[AgentInterface, None] = None, ) -> AgentState: """Create a new agent using a config""" + user_id = actor.id if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") @@ -801,7 +805,7 @@ class SyncServer(Server): tool_objs = [] if request.tools: for tool_name in request.tools: - tool_obj = self.tool_manager.get_tool_by_name_and_user_id(tool_name=tool_name, user_id=user_id) + tool_obj = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor) tool_objs.append(tool_obj) assert request.memory is not None @@ -822,9 +826,8 @@ class SyncServer(Server): source_type=source_type, tags=tags, json_schema=json_schema, - user_id=user_id, - organization_id=user.organization_id, - ) + ), + actor=actor, ) tool_objs.append(tool) if not request.tools: @@ -887,11 +890,14 @@ class SyncServer(Server): def update_agent( self, request: UpdateAgentState, - user_id: str, + actor: User, ): """Update the agents core memory block, return the new state""" - if self.user_manager.get_user_by_id(user_id=user_id) is None: - raise ValueError(f"User user_id={user_id} does not exist") + try: + self.user_manager.get_user_by_id(user_id=actor.id) + except Exception: + raise ValueError(f"User user_id={actor.id} does not exist") + if self.ms.get_agent(agent_id=request.id) is None: raise ValueError(f"Agent agent_id={request.id} does not exist") @@ -902,7 +908,7 @@ class SyncServer(Server): if request.memory: assert isinstance(request.memory, Memory), type(request.memory) new_memory_contents = request.memory.to_flat_dict() - _ = self.update_agent_core_memory(user_id=user_id, agent_id=request.id, new_memory_contents=new_memory_contents) + _ = self.update_agent_core_memory(user_id=actor.id, agent_id=request.id, new_memory_contents=new_memory_contents) # update the system prompt if request.system: @@ -922,7 +928,7 @@ class SyncServer(Server): # (1) get tools + make sure they exist tool_objs = [] for tool_name in request.tools: - tool_obj = self.tool_manager.get_tool_by_name_and_user_id(tool_name=tool_name, user_id=user_id) + tool_obj = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor) assert tool_obj, f"Tool {tool_name} does not exist" tool_objs.append(tool_obj) @@ -968,8 +974,11 @@ class SyncServer(Server): user_id: str, ): """Add tools from an existing agent""" - if self.user_manager.get_user_by_id(user_id=user_id) is None: + try: + user = self.user_manager.get_user_by_id(user_id=user_id) + except NoResultFound: raise ValueError(f"User user_id={user_id} does not exist") + if self.ms.get_agent(agent_id=agent_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") @@ -978,12 +987,12 @@ class SyncServer(Server): # Get all the tool objects from the request tool_objs = [] - tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool_id) + tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool_id, actor=user) assert tool_obj, f"Tool with id={tool_id} does not exist" tool_objs.append(tool_obj) for tool in letta_agent.tools: - tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id) + tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=user) assert tool_obj, f"Tool with id={tool.id} does not exist" # If it's not the already added tool @@ -1007,8 +1016,11 @@ class SyncServer(Server): user_id: str, ): """Remove tools from an existing agent""" - if self.user_manager.get_user_by_id(user_id=user_id) is None: + try: + user = self.user_manager.get_user_by_id(user_id=user_id) + except NoResultFound: raise ValueError(f"User user_id={user_id} does not exist") + if self.ms.get_agent(agent_id=agent_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") @@ -1018,7 +1030,7 @@ class SyncServer(Server): # Get all the tool_objs tool_objs = [] for tool in letta_agent.tools: - tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id) + tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=user) assert tool_obj, f"Tool with id={tool.id} does not exist" # If it's not the tool we want to remove @@ -1733,7 +1745,7 @@ class SyncServer(Server): return sources_with_metadata - def add_default_external_tools(self, user_id: str, org_id: str) -> bool: + def add_default_external_tools(self, actor: User) -> bool: """Add default langchain tools. Return true if successful, false otherwise.""" success = True tool_creates = ToolCreate.load_default_langchain_tools() + ToolCreate.load_default_crewai_tools() @@ -1741,7 +1753,7 @@ class SyncServer(Server): tool_creates += ToolCreate.load_default_composio_tools() for tool_create in tool_creates: try: - self.tool_manager.create_or_update_tool(tool_create) + self.tool_manager.create_or_update_tool(tool_create, actor=actor) except Exception as e: warnings.warn(f"An error occurred while creating tool {tool_create}: {e}") warnings.warn(traceback.format_exc()) @@ -1843,6 +1855,16 @@ class SyncServer(Server): except ValueError: raise HTTPException(status_code=404, detail=f"User with id {user_id} not found") + def get_organization_or_default(self, org_id: Optional[str]) -> Organization: + """Get the organization object for org_id if it exists, otherwise return the default organization object""" + if org_id is None: + org_id = self.organization_manager.DEFAULT_ORG_ID + + try: + return self.organization_manager.get_organization_by_id(org_id=org_id) + except ValueError: + raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found") + def list_llm_models(self) -> List[LLMConfig]: """List available models""" diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 2f939613..e39d7462 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -9,9 +9,9 @@ from letta.functions.functions import derive_openai_json_schema, load_function_s from letta.orm.errors import NoResultFound from letta.orm.organization import Organization as OrganizationModel from letta.orm.tool import Tool as ToolModel -from letta.orm.user import User as UserModel from letta.schemas.tool import Tool as PydanticTool from letta.schemas.tool import ToolCreate, ToolUpdate +from letta.schemas.user import User as PydanticUser from letta.utils import enforce_types @@ -25,7 +25,7 @@ class ToolManager: self.session_maker = db_context @enforce_types - def create_or_update_tool(self, tool_create: ToolCreate) -> PydanticTool: + def create_or_update_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool: """Create a new tool based on the ToolCreate schema.""" # Derive json_schema derived_json_schema = tool_create.json_schema or derive_openai_json_schema(tool_create) @@ -34,105 +34,72 @@ class ToolManager: try: # NOTE: We use the organization id here # This is important, because even if it's a different user, adding the same tool to the org should not happen - tool = self.get_tool_by_name_and_org_id(tool_name=derived_name, organization_id=tool_create.organization_id) + tool = self.get_tool_by_name(tool_name=derived_name, actor=actor) # Put to dict and remove fields that should not be reset - update_data = tool_create.model_dump(exclude={"user_id", "organization_id", "module", "terminal"}, exclude_unset=True) + update_data = tool_create.model_dump(exclude={"module", "terminal"}, exclude_unset=True) # Remove redundant update fields update_data = {key: value for key, value in update_data.items() if getattr(tool, key) != value} # If there's anything to update if update_data: - self.update_tool_by_id(tool.id, ToolUpdate(**update_data)) + self.update_tool_by_id(tool.id, ToolUpdate(**update_data), actor) else: warnings.warn( - f"`create_or_update_tool` was called with user_id={tool_create.user_id}, organization_id={tool_create.organization_id}, name={tool_create.name}, but found existing tool with nothing to update." + f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={tool_create.name}, but found existing tool with nothing to update." ) except NoResultFound: tool_create.json_schema = derived_json_schema tool_create.name = derived_name - tool = self.create_tool(tool_create) + tool = self.create_tool(tool_create, actor=actor) return tool @enforce_types - def create_tool(self, tool_create: ToolCreate) -> PydanticTool: + def create_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool: """Create a new tool based on the ToolCreate schema.""" # Create the tool with self.session_maker() as session: - # Include all fields except 'terminal' (which is not part of ToolModel) at the moment + # Include all fields except `terminal` (which is not part of ToolModel) at the moment create_data = tool_create.model_dump(exclude={"terminal"}) - tool = ToolModel(**create_data) # Unpack everything directly into ToolModel - tool.create(session) + tool = ToolModel(**create_data, organization_id=actor.organization_id) # Unpack everything directly into ToolModel + tool.create(session, actor=actor) return tool.to_pydantic() @enforce_types - def get_tool_by_id(self, tool_id: str) -> PydanticTool: + def get_tool_by_id(self, tool_id: str, actor: PydanticUser) -> PydanticTool: """Fetch a tool by its ID.""" with self.session_maker() as session: - try: - # Retrieve tool by id using the Tool model's read method - tool = ToolModel.read(db_session=session, identifier=tool_id) - # Convert the SQLAlchemy Tool object to PydanticTool - return tool.to_pydantic() - except NoResultFound: - raise ValueError(f"Tool with id {tool_id} not found.") + # Retrieve tool by id using the Tool model's read method + tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor) + # Convert the SQLAlchemy Tool object to PydanticTool + return tool.to_pydantic() @enforce_types - def get_tool_by_name_and_user_id(self, tool_name: str, user_id: str) -> PydanticTool: - """Retrieve a tool by its name and organization_id.""" + def get_tool_by_name(self, tool_name: str, actor: PydanticUser): + """Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool.""" with self.session_maker() as session: - # Use the list method to apply filters - results = ToolModel.list(db_session=session, name=tool_name, _user_id=UserModel.get_uid_from_identifier(user_id)) - - # Ensure only one result is returned (since there is a unique constraint) - if not results: - raise NoResultFound(f"Tool with name {tool_name} and user_id {user_id} not found.") - - if len(results) > 1: - raise RuntimeError( - f"Multiple tools with name {tool_name} and user_id {user_id} were found. This is a serious error, and means that our table does not have uniqueness constraints properly set up. Please reach out to the letta development team if you see this error." - ) - - # Return the single result - return results[0] + tool = ToolModel.read(db_session=session, name=tool_name, actor=actor) + return tool.to_pydantic() @enforce_types - def get_tool_by_name_and_org_id(self, tool_name: str, organization_id: str) -> PydanticTool: - """Retrieve a tool by its name and organization_id.""" - with self.session_maker() as session: - # Use the list method to apply filters - results = ToolModel.list( - db_session=session, name=tool_name, _organization_id=OrganizationModel.get_uid_from_identifier(organization_id) - ) - - # Ensure only one result is returned (since there is a unique constraint) - if not results: - raise NoResultFound(f"Tool with name {tool_name} and organization_id {organization_id} not found.") - - if len(results) > 1: - raise RuntimeError( - f"Multiple tools with name {tool_name} and organization_id {organization_id} were found. This is a serious error, and means that our table does not have uniqueness constraints properly set up. Please reach out to the letta development team if you see this error." - ) - - # Return the single result - return results[0] - - @enforce_types - def list_tools_for_org(self, organization_id: str, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]: + def list_tools(self, actor: PydanticUser, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]: """List all tools with optional pagination using cursor and limit.""" with self.session_maker() as session: tools = ToolModel.list( - db_session=session, cursor=cursor, limit=limit, _organization_id=OrganizationModel.get_uid_from_identifier(organization_id) + db_session=session, + cursor=cursor, + limit=limit, + _organization_id=OrganizationModel.get_uid_from_identifier(actor.organization_id), ) return [tool.to_pydantic() for tool in tools] @enforce_types - def update_tool_by_id(self, tool_id: str, tool_update: ToolUpdate) -> None: + def update_tool_by_id(self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser) -> None: """Update a tool by its ID with the given ToolUpdate object.""" with self.session_maker() as session: # Fetch the tool by ID - tool = ToolModel.read(db_session=session, identifier=tool_id) + tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor) # Update tool attributes with only the fields that were explicitly set update_data = tool_update.model_dump(exclude_unset=True, exclude_none=True) @@ -140,20 +107,20 @@ class ToolManager: setattr(tool, key, value) # Save the updated tool to the database - tool.update(db_session=session) + tool.update(db_session=session, actor=actor) @enforce_types - def delete_tool_by_id(self, tool_id: str) -> None: + def delete_tool_by_id(self, tool_id: str, actor: PydanticUser) -> None: """Delete a tool by its ID.""" with self.session_maker() as session: try: tool = ToolModel.read(db_session=session, identifier=tool_id) - tool.delete(db_session=session) + tool.delete(db_session=session, actor=actor) except NoResultFound: raise ValueError(f"Tool with id {tool_id} not found.") @enforce_types - def add_default_tools(self, user_id: str, org_id: str, module_name="base"): + def add_default_tools(self, actor: PydanticUser, module_name="base"): """Add default tools in {module_name}.py""" full_module_name = f"letta.functions.function_sets.{module_name}" try: @@ -187,7 +154,6 @@ class ToolManager: module=schema["module"], source_code=source_code, json_schema=schema["json_schema"], - organization_id=org_id, - user_id=user_id, ), + actor=actor, ) diff --git a/letta/services/user_manager.py b/letta/services/user_manager.py index 10116c0b..c9f6b166 100644 --- a/letta/services/user_manager.py +++ b/letta/services/user_manager.py @@ -85,11 +85,8 @@ class UserManager: def get_user_by_id(self, user_id: str) -> PydanticUser: """Fetch a user by ID.""" with self.session_maker() as session: - try: - user = UserModel.read(db_session=session, identifier=user_id) - return user.to_pydantic() - except NoResultFound: - raise ValueError(f"User with id {user_id} not found.") + user = UserModel.read(db_session=session, identifier=user_id) + return user.to_pydantic() @enforce_types def get_default_user(self) -> PydanticUser: diff --git a/tests/test_local_client.py b/tests/test_local_client.py index a7297cb2..860dce2c 100644 --- a/tests/test_local_client.py +++ b/tests/test_local_client.py @@ -135,8 +135,8 @@ def test_agent_add_remove_tools(client: LocalClient, agent): assert github_tool.id in [t.id for t in tools] assert scrape_website_tool.id in [t.id for t in tools] - # Assert that all combinations of tool_names, tool_user_ids are unique - combinations = [(t.name, t.user_id) for t in tools] + # Assert that all combinations of tool_names, organization id are unique + combinations = [(t.name, t.organization_id) for t in tools] assert len(combinations) == len(set(combinations)) # create agent diff --git a/tests/test_managers.py b/tests/test_managers.py index b6ec71b7..1dd3743d 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -47,18 +47,17 @@ def tool_fixture(server: SyncServer): org = server.organization_manager.create_default_organization() user = server.user_manager.create_default_user() - tool_create = ToolCreate( - user_id=user.id, organization_id=org.id, description=description, tags=tags, source_code=source_code, source_type=source_type - ) + other_user = server.user_manager.create_user(UserCreate(name="other", organization_id=org.id)) + tool_create = ToolCreate(description=description, tags=tags, source_code=source_code, source_type=source_type) derived_json_schema = derive_openai_json_schema(tool_create) derived_name = derived_json_schema["name"] tool_create.json_schema = derived_json_schema tool_create.name = derived_name - tool = server.tool_manager.create_tool(tool_create) + tool = server.tool_manager.create_tool(tool_create, actor=user) # Yield the created tool, organization, and user for use in tests - yield {"tool": tool, "organization": org, "user": user, "tool_create": tool_create} + yield {"tool": tool, "organization": org, "user": user, "other_user": other_user, "tool_create": tool_create} @pytest.fixture(scope="module") @@ -177,7 +176,7 @@ def test_create_tool(server: SyncServer, tool_fixture): org = tool_fixture["organization"] # Assertions to ensure the created tool matches the expected values - assert tool.user_id == user.id + assert tool.created_by_id == user.id assert tool.organization_id == org.id assert tool.description == tool_create.description assert tool.tags == tool_create.tags @@ -188,9 +187,10 @@ def test_create_tool(server: SyncServer, tool_fixture): def test_get_tool_by_id(server: SyncServer, tool_fixture): tool = tool_fixture["tool"] + user = tool_fixture["user"] # Fetch the tool by ID using the manager method - fetched_tool = server.tool_manager.get_tool_by_id(tool.id) + fetched_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user) # Assertions to check if the fetched tool matches the created tool assert fetched_tool.id == tool.id @@ -201,34 +201,17 @@ def test_get_tool_by_id(server: SyncServer, tool_fixture): assert fetched_tool.source_type == tool.source_type -def test_get_tool_by_name_and_org_id(server: SyncServer, tool_fixture): - tool = tool_fixture["tool"] - org = tool_fixture["organization"] - - # Fetch the tool by name and organization ID - fetched_tool = server.tool_manager.get_tool_by_name_and_org_id(tool.name, org.id) - - # Assertions to check if the fetched tool matches the created tool - assert fetched_tool.id == tool.id - assert fetched_tool.name == tool.name - assert fetched_tool.organization_id == org.id - assert fetched_tool.description == tool.description - assert fetched_tool.tags == tool.tags - assert fetched_tool.source_code == tool.source_code - assert fetched_tool.source_type == tool.source_type - - -def test_get_tool_by_name_and_user_id(server: SyncServer, tool_fixture): +def test_get_tool_with_actor(server: SyncServer, tool_fixture): tool = tool_fixture["tool"] user = tool_fixture["user"] # Fetch the tool by name and organization ID - fetched_tool = server.tool_manager.get_tool_by_name_and_user_id(tool.name, user.id) + fetched_tool = server.tool_manager.get_tool_by_name(tool.name, actor=user) # Assertions to check if the fetched tool matches the created tool assert fetched_tool.id == tool.id assert fetched_tool.name == tool.name - assert fetched_tool.user_id == user.id + assert fetched_tool.created_by_id == user.id assert fetched_tool.description == tool.description assert fetched_tool.tags == tool.tags assert fetched_tool.source_code == tool.source_code @@ -237,10 +220,11 @@ def test_get_tool_by_name_and_user_id(server: SyncServer, tool_fixture): def test_list_tools(server: SyncServer, tool_fixture): tool = tool_fixture["tool"] - org = tool_fixture["organization"] + tool_fixture["organization"] + user = tool_fixture["user"] # List tools (should include the one created by the fixture) - tools = server.tool_manager.list_tools_for_org(organization_id=org.id) + tools = server.tool_manager.list_tools(actor=user) # Assertions to check that the created tool is listed assert len(tools) == 1 @@ -249,27 +233,50 @@ def test_list_tools(server: SyncServer, tool_fixture): def test_update_tool_by_id(server: SyncServer, tool_fixture): tool = tool_fixture["tool"] + user = tool_fixture["user"] updated_description = "updated_description" # Create a ToolUpdate object to modify the tool's description tool_update = ToolUpdate(description=updated_description) # Update the tool using the manager method - server.tool_manager.update_tool_by_id(tool.id, tool_update) + server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=user) # Fetch the updated tool to verify the changes - updated_tool = server.tool_manager.get_tool_by_id(tool.id) + updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user) # Assertions to check if the update was successful assert updated_tool.description == updated_description +def test_update_tool_multi_user(server: SyncServer, tool_fixture): + tool = tool_fixture["tool"] + user = tool_fixture["user"] + other_user = tool_fixture["other_user"] + updated_description = "updated_description" + + # Create a ToolUpdate object to modify the tool's description + tool_update = ToolUpdate(description=updated_description) + + # Update the tool using the manager method, but WITH THE OTHER USER'S ID! + server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=other_user) + + # Check that the created_by and last_updated_by fields are correct + + # Fetch the updated tool to verify the changes + updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user) + + assert updated_tool.last_updated_by_id == other_user.id + assert updated_tool.created_by_id == user.id + + def test_delete_tool_by_id(server: SyncServer, tool_fixture): tool = tool_fixture["tool"] - org = tool_fixture["organization"] + tool_fixture["organization"] + user = tool_fixture["user"] # Delete the tool using the manager method - server.tool_manager.delete_tool_by_id(tool.id) + server.tool_manager.delete_tool_by_id(tool.id, actor=user) - tools = server.tool_manager.list_tools_for_org(organization_id=org.id) + tools = server.tool_manager.list_tools(actor=user) assert len(tools) == 0 diff --git a/tests/test_server.py b/tests/test_server.py index 5bbffdd9..529e4397 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -79,7 +79,7 @@ def agent_id(server, user_id): llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config(provider="openai"), ), - user_id=user_id, + actor=server.get_user_or_default(user_id), ) print(f"Created agent\n{agent_state}") yield agent_state.id diff --git a/tests/test_tools.py b/tests/test_tools.py index e2ddf999..c5a6d2ec 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -146,7 +146,7 @@ def test_create_agent_tool(client): # create agent with tool memory = ChatMemory(human="I am a human", persona="You must clear your memory if the human instructs you") agent = client.create_agent(name=test_agent_name, tools=[tool.name], memory=memory) - assert str(tool.user_id) == str(agent.user_id), f"Expected {tool.user_id} to be {agent.user_id}" + assert str(tool.created_by_id) == str(agent.user_id), f"Expected {tool.created_by_id} to be {agent.user_id}" # initial memory initial_memory = client.get_in_context_memory(agent.id)