From 19591098215b0c75ee9f8a842fd947fd6df9afb0 Mon Sep 17 00:00:00 2001 From: Andy Li <55300002+cliandy@users.noreply.github.com> Date: Fri, 25 Jul 2025 13:06:12 -0700 Subject: [PATCH] feat: support for project_id and backfills --- ..._support_for_project_id_for_blocks_and_.py | 69 +++++++++++++++++++ letta/orm/agent.py | 5 +- letta/orm/block.py | 4 +- letta/orm/group.py | 4 +- letta/orm/identity.py | 7 +- letta/orm/step.py | 6 +- letta/schemas/block.py | 3 + letta/schemas/group.py | 3 + letta/server/rest_api/routers/v1/blocks.py | 2 + letta/server/rest_api/routers/v1/groups.py | 2 +- letta/server/rest_api/routers/v1/tools.py | 2 +- letta/services/block_manager.py | 4 ++ tests/test_client.py | 5 +- 13 files changed, 98 insertions(+), 18 deletions(-) create mode 100644 alembic/versions/06fbbf65d4f1_support_for_project_id_for_blocks_and_.py diff --git a/alembic/versions/06fbbf65d4f1_support_for_project_id_for_blocks_and_.py b/alembic/versions/06fbbf65d4f1_support_for_project_id_for_blocks_and_.py new file mode 100644 index 00000000..a2c0e062 --- /dev/null +++ b/alembic/versions/06fbbf65d4f1_support_for_project_id_for_blocks_and_.py @@ -0,0 +1,69 @@ +"""support for project_id for blocks and groups + +Revision ID: 06fbbf65d4f1 +Revises: f55542f37641 +Create Date: 2025-07-21 15:07:32.133538 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from sqlalchemy import text + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "06fbbf65d4f1" +down_revision: Union[str, None] = "f55542f37641" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("block", sa.Column("project_id", sa.String(), nullable=True)) + op.add_column("groups", sa.Column("project_id", sa.String(), nullable=True)) + + # Backfill project_id for blocks table + # Since all agents for a block have the same project_id, we can just grab the first one + op.execute( + text( + """ + UPDATE block + SET project_id = ( + SELECT a.project_id + FROM blocks_agents ba + JOIN agents a ON ba.agent_id = a.id + WHERE ba.block_id = block.id + AND a.project_id IS NOT NULL + LIMIT 1 + ) + """ + ) + ) + + # Backfill project_id for groups table + op.execute( + text( + """ + UPDATE groups + SET project_id = ( + SELECT a.project_id + FROM groups_agents ga + JOIN agents a ON ga.agent_id = a.id + WHERE ga.group_id = groups.id + AND a.project_id IS NOT NULL + LIMIT 1 + ) + """ + ) + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("groups", "project_id") + op.drop_column("block", "project_id") + # ### end Alembic commands ### diff --git a/letta/orm/agent.py b/letta/orm/agent.py index bc2879d8..a0128f87 100644 --- a/letta/orm/agent.py +++ b/letta/orm/agent.py @@ -10,7 +10,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.block import Block from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ResponseFormatColumn, ToolRulesColumn from letta.orm.identity import Identity -from letta.orm.mixins import OrganizationMixin +from letta.orm.mixins import OrganizationMixin, ProjectMixin from letta.orm.organization import Organization from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.agent import AgentState as PydanticAgentState @@ -31,7 +31,7 @@ if TYPE_CHECKING: from letta.orm.tool import Tool -class Agent(SqlalchemyBase, OrganizationMixin, AsyncAttrs): +class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, AsyncAttrs): __tablename__ = "agents" __pydantic_model__ = PydanticAgentState __table_args__ = (Index("ix_agents_created_at", "created_at", "id"),) @@ -67,7 +67,6 @@ class Agent(SqlalchemyBase, OrganizationMixin, AsyncAttrs): embedding_config: Mapped[Optional[EmbeddingConfig]] = mapped_column( EmbeddingConfigColumn, doc="the embedding configuration object for this agent." ) - project_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The id of the project the agent belongs to.") template_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The id of the template the agent belongs to.") base_template_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The base template id of the agent.") diff --git a/letta/orm/block.py b/letta/orm/block.py index bbbb6dfc..a31b5a87 100644 --- a/letta/orm/block.py +++ b/letta/orm/block.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Mapped, attributes, declared_attr, mapped_column, rel from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT from letta.orm.block_history import BlockHistory from letta.orm.blocks_agents import BlocksAgents -from letta.orm.mixins import OrganizationMixin +from letta.orm.mixins import OrganizationMixin, ProjectMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import Human, Persona @@ -16,7 +16,7 @@ if TYPE_CHECKING: from letta.orm.identity import Identity -class Block(OrganizationMixin, SqlalchemyBase): +class Block(OrganizationMixin, SqlalchemyBase, ProjectMixin): """Blocks are sections of the LLM context, representing a specific part of the total Memory""" __tablename__ = "block" diff --git a/letta/orm/group.py b/letta/orm/group.py index 489e563f..fe2cfe7e 100644 --- a/letta/orm/group.py +++ b/letta/orm/group.py @@ -4,12 +4,12 @@ from typing import List, Optional from sqlalchemy import JSON, ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column, relationship -from letta.orm.mixins import OrganizationMixin +from letta.orm.mixins import OrganizationMixin, ProjectMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.group import Group as PydanticGroup -class Group(SqlalchemyBase, OrganizationMixin): +class Group(SqlalchemyBase, OrganizationMixin, ProjectMixin): __tablename__ = "groups" __pydantic_model__ = PydanticGroup diff --git a/letta/orm/identity.py b/letta/orm/identity.py index ac83e952..0d4f13ca 100644 --- a/letta/orm/identity.py +++ b/letta/orm/identity.py @@ -1,17 +1,17 @@ import uuid -from typing import List, Optional +from typing import List from sqlalchemy import String, UniqueConstraint from sqlalchemy.dialects.postgresql import JSON from sqlalchemy.orm import Mapped, mapped_column, relationship -from letta.orm.mixins import OrganizationMixin +from letta.orm.mixins import OrganizationMixin, ProjectMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.identity import Identity as PydanticIdentity from letta.schemas.identity import IdentityProperty -class Identity(SqlalchemyBase, OrganizationMixin): +class Identity(SqlalchemyBase, OrganizationMixin, ProjectMixin): """Identity ORM class""" __tablename__ = "identities" @@ -32,7 +32,6 @@ class Identity(SqlalchemyBase, OrganizationMixin): identifier_key: Mapped[str] = mapped_column(nullable=False, doc="External, user-generated identifier key of the identity.") name: Mapped[str] = mapped_column(nullable=False, doc="The name of the identity.") identity_type: Mapped[str] = mapped_column(nullable=False, doc="The type of the identity.") - project_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The project id of the identity.") properties: Mapped[List["IdentityProperty"]] = mapped_column( JSON, nullable=False, default=list, doc="List of properties associated with the identity" ) diff --git a/letta/orm/step.py b/letta/orm/step.py index 05da631c..e35aa135 100644 --- a/letta/orm/step.py +++ b/letta/orm/step.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional from sqlalchemy import JSON, ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column, relationship +from letta.orm.mixins import ProjectMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.letta_stop_reason import StopReasonType from letta.schemas.step import Step as PydanticStep @@ -13,7 +14,7 @@ if TYPE_CHECKING: from letta.orm.provider import Provider -class Step(SqlalchemyBase): +class Step(SqlalchemyBase, ProjectMixin): """Tracks all metadata for agent step.""" __tablename__ = "steps" @@ -53,9 +54,6 @@ class Step(SqlalchemyBase): feedback: Mapped[Optional[str]] = mapped_column( None, nullable=True, doc="The feedback for this step. Must be either 'positive' or 'negative'." ) - project_id: Mapped[Optional[str]] = mapped_column( - None, nullable=True, doc="The project that the agent that executed this step belongs to (cloud only)." - ) # Relationships (foreign keys) organization: Mapped[Optional["Organization"]] = relationship("Organization") diff --git a/letta/schemas/block.py b/letta/schemas/block.py index ea30da77..05604575 100644 --- a/letta/schemas/block.py +++ b/letta/schemas/block.py @@ -19,6 +19,7 @@ class BaseBlock(LettaBase, validate_assignment=True): value: str = Field(..., description="Value of the block.") limit: int = Field(CORE_MEMORY_BLOCK_CHAR_LIMIT, description="Character limit of the block.") + project_id: Optional[str] = Field(None, description="The associated project id.") # template data (optional) template_name: Optional[str] = Field(None, description="Name of the block if it is a template.", alias="name") is_template: bool = Field(False, description="Whether the block is a template (e.g. saved human/persona options).") @@ -112,6 +113,7 @@ class BlockUpdate(BaseBlock): limit: Optional[int] = Field(None, description="Character limit of the block.") value: Optional[str] = Field(None, description="Value of the block.") + project_id: Optional[str] = Field(None, description="The associated project id.") class Config: extra = "ignore" # Ignores extra fields @@ -124,6 +126,7 @@ class CreateBlock(BaseBlock): limit: int = Field(CORE_MEMORY_BLOCK_CHAR_LIMIT, description="Character limit of the block.") value: str = Field(..., description="Value of the block.") + project_id: Optional[str] = Field(None, description="The associated project id.") # block templates is_template: bool = False template_name: Optional[str] = Field(None, description="Name of the block if it is a template.", alias="name") diff --git a/letta/schemas/group.py b/letta/schemas/group.py index de40ba5d..fdbfff6d 100644 --- a/letta/schemas/group.py +++ b/letta/schemas/group.py @@ -24,6 +24,7 @@ class Group(GroupBase): manager_type: ManagerType = Field(..., description="") agent_ids: List[str] = Field(..., description="") description: str = Field(..., description="") + project_id: Optional[str] = Field(None, description="The associated project id.") shared_block_ids: List[str] = Field([], description="") # Pattern fields manager_agent_id: Optional[str] = Field(None, description="") @@ -138,6 +139,7 @@ class GroupCreate(BaseModel): agent_ids: List[str] = Field(..., description="") description: str = Field(..., description="") manager_config: ManagerConfigUnion = Field(RoundRobinManager(), description="") + project_id: Optional[str] = Field(None, description="The associated project id.") shared_block_ids: List[str] = Field([], description="") @@ -145,4 +147,5 @@ class GroupUpdate(BaseModel): agent_ids: Optional[List[str]] = Field(None, description="") description: Optional[str] = Field(None, description="") manager_config: Optional[ManagerConfigUpdateUnion] = Field(None, description="") + project_id: Optional[str] = Field(None, description="The associated project id.") shared_block_ids: Optional[List[str]] = Field(None, description="") diff --git a/letta/server/rest_api/routers/v1/blocks.py b/letta/server/rest_api/routers/v1/blocks.py index 2ccdc6f5..0320b832 100644 --- a/letta/server/rest_api/routers/v1/blocks.py +++ b/letta/server/rest_api/routers/v1/blocks.py @@ -22,6 +22,7 @@ async def list_blocks( name: Optional[str] = Query(None, description="Name of the block"), identity_id: Optional[str] = Query(None, description="Search agents by identifier id"), identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"), + project_id: Optional[str] = Query(None, description="Search blocks by project id"), limit: Optional[int] = Query(50, description="Number of blocks to return"), server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present @@ -34,6 +35,7 @@ async def list_blocks( template_name=name, identity_id=identity_id, identifier_keys=identifier_keys, + project_id=project_id, limit=limit, ) diff --git a/letta/server/rest_api/routers/v1/groups.py b/letta/server/rest_api/routers/v1/groups.py index bbb51157..14f95115 100644 --- a/letta/server/rest_api/routers/v1/groups.py +++ b/letta/server/rest_api/routers/v1/groups.py @@ -31,12 +31,12 @@ def list_groups( """ actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.group_manager.list_groups( + actor=actor, project_id=project_id, manager_type=manager_type, before=before, after=after, limit=limit, - actor=actor, ) diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 354afdc2..1165ebcd 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -703,7 +703,7 @@ async def connect_mcp_server( """ async def oauth_stream_generator( - request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig] + request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig], ) -> AsyncGenerator[str, None]: client = None oauth_provider = None diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index d8682bd8..c0799026 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -178,6 +178,7 @@ class BlockManager: template_name: Optional[str] = None, identity_id: Optional[str] = None, identifier_keys: Optional[List[str]] = None, + project_id: Optional[str] = None, before: Optional[str] = None, after: Optional[str] = None, limit: Optional[int] = 50, @@ -210,6 +211,9 @@ class BlockManager: if template_name: query = query.where(BlockModel.template_name == template_name) + if project_id: + query = query.where(BlockModel.project_id == project_id) + needs_distinct = False if identifier_keys: diff --git a/tests/test_client.py b/tests/test_client.py index 8f6c2a23..77507b0b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -710,7 +710,7 @@ def test_attach_sleeptime_block(client: Letta): sleeptime_id = [id for id in agent_ids if id != agent.id][0] # attach a new block - block = client.blocks.create(label="test", value="test") + block = client.blocks.create(label="test", value="test") # , project_id="test") client.agents.blocks.attach(agent_id=agent.id, block_id=block.id) # verify block is attached to both agents @@ -720,5 +720,8 @@ def test_attach_sleeptime_block(client: Letta): blocks = client.agents.blocks.list(agent_id=sleeptime_id) assert block.id in [b.id for b in blocks] + # blocks = client.blocks.list(project_id="test") + # assert block.id in [b.id for b in blocks] + # cleanup client.agents.delete(agent.id)