From 2bb3baf06096df988f6b7967359bf6d689976a51 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 19 Nov 2024 11:32:33 -0800 Subject: [PATCH] feat: Move blocks to ORM model (#1980) Co-authored-by: Sarah Wooders --- ...7507eab4bb9_migrate_blocks_to_orm_model.py | 74 +++++++++ letta/agent.py | 17 +- letta/cli/cli.py | 3 +- letta/client/client.py | 63 +++---- letta/metadata.py | 156 +----------------- letta/o1_agent.py | 4 +- letta/orm/__init__.py | 1 + letta/orm/block.py | 44 +++++ letta/orm/organization.py | 1 + letta/schemas/block.py | 57 ++++--- letta/server/rest_api/routers/v1/blocks.py | 40 ++--- letta/server/server.py | 91 +--------- letta/services/block_manager.py | 103 ++++++++++++ tests/helpers/endpoints_helper.py | 6 +- tests/integration_test_summarizer.py | 4 +- tests/test_client.py | 6 +- tests/test_local_client.py | 10 +- tests/test_managers.py | 84 +++++++++- tests/test_memory.py | 25 +-- tests/test_tools.py | 40 +---- 20 files changed, 431 insertions(+), 398 deletions(-) create mode 100644 alembic/versions/f7507eab4bb9_migrate_blocks_to_orm_model.py create mode 100644 letta/orm/block.py create mode 100644 letta/services/block_manager.py diff --git a/alembic/versions/f7507eab4bb9_migrate_blocks_to_orm_model.py b/alembic/versions/f7507eab4bb9_migrate_blocks_to_orm_model.py new file mode 100644 index 00000000..9e7fa270 --- /dev/null +++ b/alembic/versions/f7507eab4bb9_migrate_blocks_to_orm_model.py @@ -0,0 +1,74 @@ +"""Migrate blocks to orm model + +Revision ID: f7507eab4bb9 +Revises: c85a3d07c028 +Create Date: 2024-11-18 15:40:13.149438 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "f7507eab4bb9" +down_revision: Union[str, None] = "c85a3d07c028" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("block", sa.Column("is_template", sa.Boolean(), nullable=True)) + # Populate `is_template` column + op.execute( + """ + UPDATE block + SET is_template = COALESCE(template, FALSE) + """ + ) + + # Step 2: Make `is_template` non-nullable + op.alter_column("block", "is_template", nullable=False) + op.add_column("block", sa.Column("organization_id", sa.String(), nullable=True)) + # Populate `organization_id` based on `user_id` + # Use a raw SQL query to update the organization_id + op.execute( + """ + UPDATE block + SET organization_id = users.organization_id + FROM users + WHERE block.user_id = users.id + """ + ) + op.alter_column("block", "organization_id", nullable=False) + op.add_column("block", sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True)) + op.add_column("block", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True)) + op.add_column("block", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False)) + op.add_column("block", sa.Column("_created_by_id", sa.String(), nullable=True)) + op.add_column("block", sa.Column("_last_updated_by_id", sa.String(), nullable=True)) + op.alter_column("block", "limit", existing_type=sa.BIGINT(), type_=sa.Integer(), nullable=False) + op.drop_index("block_idx_user", table_name="block") + op.create_foreign_key(None, "block", "organizations", ["organization_id"], ["id"]) + op.drop_column("block", "template") + op.drop_column("block", "user_id") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("block", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=True)) + op.add_column("block", sa.Column("template", sa.BOOLEAN(), autoincrement=False, nullable=True)) + op.drop_constraint(None, "block", type_="foreignkey") + op.create_index("block_idx_user", "block", ["user_id"], unique=False) + op.alter_column("block", "limit", existing_type=sa.Integer(), type_=sa.BIGINT(), nullable=True) + op.drop_column("block", "_last_updated_by_id") + op.drop_column("block", "_created_by_id") + op.drop_column("block", "is_deleted") + op.drop_column("block", "updated_at") + op.drop_column("block", "created_at") + op.drop_column("block", "organization_id") + op.drop_column("block", "is_template") + # ### end Alembic commands ### diff --git a/letta/agent.py b/letta/agent.py index 50264bc7..5e3fd2df 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -27,6 +27,7 @@ from letta.llm_api.llm_api_tools import create from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages from letta.memory import ArchivalMemory, RecallMemory, summarize_messages from letta.metadata import MetadataStore +from letta.orm import User from letta.persistence_manager import LocalStateManager from letta.schemas.agent import AgentState, AgentStepResponse from letta.schemas.block import Block @@ -46,6 +47,7 @@ from letta.schemas.passage import Passage from letta.schemas.tool import Tool from letta.schemas.tool_rule import TerminalToolRule from letta.schemas.usage import LettaUsageStatistics +from letta.services.block_manager import BlockManager from letta.services.source_manager import SourceManager from letta.services.user_manager import UserManager from letta.streaming_interface import StreamingRefreshCLIInterface @@ -234,6 +236,7 @@ class Agent(BaseAgent): # agents can be created from providing agent_state agent_state: AgentState, tools: List[Tool], + user: User, # memory: Memory, # extras messages_total: Optional[int] = None, # TODO remove? @@ -245,6 +248,8 @@ class Agent(BaseAgent): self.agent_state = agent_state assert isinstance(self.agent_state.memory, Memory), f"Memory object is not of type Memory: {type(self.agent_state.memory)}" + self.user = user + # link tools self.link_tools(tools) @@ -1221,7 +1226,9 @@ class Agent(BaseAgent): # future if we expect templates to change often. continue block_id = block.get("id") - db_block = ms.get_block(block_id=block_id) + + # TODO: This is really hacky and we should probably figure out how to + db_block = BlockManager().get_block_by_id(block_id=block_id, actor=self.user) if db_block is None: # this case covers if someone has deleted a shared block by interacting # with some other agent. @@ -1598,7 +1605,7 @@ def save_agent(agent: Agent, ms: MetadataStore): # NOTE: we're saving agent memory before persisting the agent to ensure # that allocated block_ids for each memory block are present in the agent model - save_agent_memory(agent=agent, ms=ms) + save_agent_memory(agent=agent) if ms.get_agent(agent_id=agent.agent_state.id): ms.update_agent(agent_state) @@ -1609,7 +1616,7 @@ def save_agent(agent: Agent, ms: MetadataStore): assert isinstance(agent.agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}" -def save_agent_memory(agent: Agent, ms: MetadataStore): +def save_agent_memory(agent: Agent): """ Save agent memory to metadata store. Memory is a collection of blocks and each block is persisted to the block table. @@ -1618,14 +1625,12 @@ def save_agent_memory(agent: Agent, ms: MetadataStore): for block_dict in agent.memory.to_dict()["memory"].values(): # TODO: block creation should happen in one place to enforce these sort of constraints consistently. - if block_dict.get("user_id", None) is None: - block_dict["user_id"] = agent.agent_state.user_id block = Block(**block_dict) # FIXME: should we expect for block values to be None? If not, we need to figure out why that is # the case in some tests, if so we should relax the DB constraint. if block.value is None: block.value = "" - ms.update_or_create_block(block) + BlockManager().create_or_update_block(block, actor=agent.user) def strip_name_field_from_user_message(user_message_text: str) -> Tuple[str, Optional[str]]: diff --git a/letta/cli/cli.py b/letta/cli/cli.py index 56b79fb1..076a179a 100644 --- a/letta/cli/cli.py +++ b/letta/cli/cli.py @@ -220,7 +220,7 @@ def run( # create agent tools = [server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=client.user) for tool_name in agent_state.tools] - letta_agent = Agent(agent_state=agent_state, interface=interface(), tools=tools) + letta_agent = Agent(agent_state=agent_state, interface=interface(), tools=tools, user=client.user) else: # create new agent # create new agent config: override defaults with args if provided @@ -320,6 +320,7 @@ def run( tools=tools, # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False, + user=client.user, ) save_agent(agent=letta_agent, ms=ms) typer.secho(f"🎉 Created new agent '{letta_agent.agent_state.name}' (id={letta_agent.agent_state.id})", fg=typer.colors.GREEN) diff --git a/letta/client/client.py b/letta/client/client.py index 65099d57..eeb6a50c 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -12,12 +12,10 @@ from letta.memory import get_memory_functions from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState from letta.schemas.block import ( Block, - CreateBlock, - CreateHuman, - CreatePersona, + BlockCreate, + BlockUpdate, Human, Persona, - UpdateBlock, UpdateHuman, UpdatePersona, ) @@ -883,8 +881,8 @@ class RESTClient(AbstractClient): else: return [Block(**block) for block in response.json()] - def create_block(self, label: str, text: str, template_name: Optional[str] = None, template: bool = False) -> Block: # - request = CreateBlock(label=label, value=text, template=template, template_name=template_name) + def create_block(self, label: str, value: str, template_name: Optional[str] = None, is_template: bool = False) -> Block: # + request = BlockCreate(label=label, value=value, template=is_template, template_name=template_name) response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks", json=request.model_dump(), headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to create block: {response.text}") @@ -896,7 +894,7 @@ class RESTClient(AbstractClient): return Block(**response.json()) def update_block(self, block_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Block: - request = UpdateBlock(id=block_id, template_name=name, value=text) + request = BlockUpdate(id=block_id, template_name=name, value=text) response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks/{block_id}", json=request.model_dump(), headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to update block: {response.text}") @@ -950,7 +948,7 @@ class RESTClient(AbstractClient): Returns: human (Human): Human block """ - return self.create_block(label="human", template_name=name, text=text, template=True) + return self.create_block(label="human", template_name=name, value=text, is_template=True) def update_human(self, human_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Human: """ @@ -990,7 +988,7 @@ class RESTClient(AbstractClient): Returns: persona (Persona): Persona block """ - return self.create_block(label="persona", template_name=name, text=text, template=True) + return self.create_block(label="persona", template_name=name, value=text, is_template=True) def update_persona(self, persona_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Persona: """ @@ -2125,8 +2123,7 @@ class LocalClient(AbstractClient): # humans / personas def get_block_id(self, name: str, label: str) -> str: - - block = self.server.get_blocks(name=name, label=label, user_id=self.user_id, template=True) + block = self.server.block_manager.get_blocks(actor=self.user, template_name=name, label=label, is_template=True) if not block: return None return block[0].id @@ -2142,7 +2139,7 @@ class LocalClient(AbstractClient): Returns: human (Human): Human block """ - return self.server.create_block(CreateHuman(template_name=name, value=text, user_id=self.user_id), user_id=self.user_id) + return self.server.block_manager.create_or_update_block(Human(template_name=name, value=text), actor=self.user) def create_persona(self, name: str, text: str): """ @@ -2155,7 +2152,7 @@ class LocalClient(AbstractClient): Returns: persona (Persona): Persona block """ - return self.server.create_block(CreatePersona(template_name=name, value=text, user_id=self.user_id), user_id=self.user_id) + return self.server.block_manager.create_or_update_block(Persona(template_name=name, value=text), actor=self.user) def list_humans(self): """ @@ -2164,7 +2161,7 @@ class LocalClient(AbstractClient): Returns: humans (List[Human]): List of human blocks """ - return self.server.get_blocks(label="human", user_id=self.user_id, template=True) + return self.server.block_manager.get_blocks(actor=self.user, label="human", is_template=True) def list_personas(self) -> List[Persona]: """ @@ -2173,7 +2170,7 @@ class LocalClient(AbstractClient): Returns: personas (List[Persona]): List of persona blocks """ - return self.server.get_blocks(label="persona", user_id=self.user_id, template=True) + return self.server.block_manager.get_blocks(actor=self.user, label="persona", is_template=True) def update_human(self, human_id: str, text: str): """ @@ -2186,7 +2183,9 @@ class LocalClient(AbstractClient): Returns: human (Human): Updated human block """ - return self.server.update_block(UpdateHuman(id=human_id, value=text, user_id=self.user_id, template=True)) + return self.server.block_manager.update_block( + block_id=human_id, block_update=UpdateHuman(value=text, is_template=True), actor=self.user + ) def update_persona(self, persona_id: str, text: str): """ @@ -2199,7 +2198,9 @@ class LocalClient(AbstractClient): Returns: persona (Persona): Updated persona block """ - return self.server.update_block(UpdatePersona(id=persona_id, value=text, user_id=self.user_id, template=True)) + return self.server.block_manager.update_block( + block_id=persona_id, block_update=UpdatePersona(value=text, is_template=True), actor=self.user + ) def get_persona(self, id: str) -> Persona: """ @@ -2212,7 +2213,7 @@ class LocalClient(AbstractClient): persona (Persona): Persona block """ assert id, f"Persona ID must be provided" - return Persona(**self.server.get_block(id).model_dump()) + return Persona(**self.server.block_manager.get_block_by_id(id, actor=self.user).model_dump()) def get_human(self, id: str) -> Human: """ @@ -2225,7 +2226,7 @@ class LocalClient(AbstractClient): human (Human): Human block """ assert id, f"Human ID must be provided" - return Human(**self.server.get_block(id).model_dump()) + return Human(**self.server.block_manager.get_block_by_id(id, actor=self.user).model_dump()) def get_persona_id(self, name: str) -> str: """ @@ -2237,7 +2238,7 @@ class LocalClient(AbstractClient): Returns: id (str): ID of the persona block """ - persona = self.server.get_blocks(name=name, label="persona", user_id=self.user_id, template=True) + persona = self.server.block_manager.get_blocks(actor=self.user, template_name=name, label="persona", is_template=True) if not persona: return None return persona[0].id @@ -2252,7 +2253,7 @@ class LocalClient(AbstractClient): Returns: id (str): ID of the human block """ - human = self.server.get_blocks(name=name, label="human", user_id=self.user_id, template=True) + human = self.server.block_manager.get_blocks(actor=self.user, template_name=name, label="human", is_template=True) if not human: return None return human[0].id @@ -2264,7 +2265,7 @@ class LocalClient(AbstractClient): Args: id (str): ID of the persona block """ - self.server.delete_block(id) + self.delete_block(id) def delete_human(self, id: str): """ @@ -2273,7 +2274,7 @@ class LocalClient(AbstractClient): Args: id (str): ID of the human block """ - self.server.delete_block(id) + self.delete_block(id) # tools def load_langchain_tool(self, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool: @@ -2661,9 +2662,9 @@ class LocalClient(AbstractClient): Returns: blocks (List[Block]): List of blocks """ - return self.server.get_blocks(label=label, template=templates_only) + return self.server.block_manager.get_blocks(actor=self.user, label=label, is_template=templates_only) - def create_block(self, label: str, text: str, template_name: Optional[str] = None, template: bool = False) -> Block: # + def create_block(self, label: str, value: str, template_name: Optional[str] = None, is_template: bool = False) -> Block: # """ Create a block @@ -2675,8 +2676,8 @@ class LocalClient(AbstractClient): Returns: block (Block): Created block """ - return self.server.create_block( - CreateBlock(label=label, template_name=template_name, value=text, user_id=self.user_id, template=template), user_id=self.user_id + return self.server.block_manager.create_or_update_block( + Block(label=label, template_name=template_name, value=value, is_template=is_template), actor=self.user ) def update_block(self, block_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Block: @@ -2691,7 +2692,9 @@ class LocalClient(AbstractClient): Returns: block (Block): Updated block """ - return self.server.update_block(UpdateBlock(id=block_id, template_name=name, value=text)) + return self.server.block_manager.update_block( + block_id=block_id, block_update=BlockUpdate(template_name=name, value=text), actor=self.user + ) def get_block(self, block_id: str) -> Block: """ @@ -2703,7 +2706,7 @@ class LocalClient(AbstractClient): Returns: block (Block): Block """ - return self.server.get_block(block_id) + return self.server.block_manager.get_block_by_id(block_id, actor=self.user) def delete_block(self, id: str) -> Block: """ @@ -2715,7 +2718,7 @@ class LocalClient(AbstractClient): Returns: block (Block): Deleted block """ - return self.server.delete_block(id) + return self.server.block_manager.delete_block(id, actor=self.user) def set_default_llm_config(self, llm_config: LLMConfig): """ diff --git a/letta/metadata.py b/letta/metadata.py index dc87d032..449cf65b 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -4,23 +4,13 @@ import os import secrets from typing import List, Optional -from sqlalchemy import ( - BIGINT, - JSON, - Boolean, - Column, - DateTime, - Index, - String, - TypeDecorator, -) +from sqlalchemy import JSON, Column, DateTime, Index, String, TypeDecorator from sqlalchemy.sql import func from letta.config import LettaConfig from letta.orm.base import Base from letta.schemas.agent import AgentState from letta.schemas.api_key import APIKey -from letta.schemas.block import Block, Human, Persona from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import JobStatus from letta.schemas.job import Job @@ -269,63 +259,6 @@ class AgentSourceMappingModel(Base): return f"" -class BlockModel(Base): - __tablename__ = "block" - __table_args__ = {"extend_existing": True} - - id = Column(String, primary_key=True, nullable=False) - value = Column(String, nullable=False) - limit = Column(BIGINT) - template_name = Column(String, nullable=True, default=None) - template = Column(Boolean, default=False) # True: listed as possible human/persona - label = Column(String, nullable=False) - metadata_ = Column(JSON) - description = Column(String) - user_id = Column(String) - Index(__tablename__ + "_idx_user", user_id), - - def __repr__(self) -> str: - return f"" - - def to_record(self) -> Block: - if self.label == "persona": - return Persona( - id=self.id, - value=self.value, - limit=self.limit, - template_name=self.template_name, - template=self.template, - label=self.label, - metadata_=self.metadata_, - description=self.description, - user_id=self.user_id, - ) - elif self.label == "human": - return Human( - id=self.id, - value=self.value, - limit=self.limit, - template_name=self.template_name, - template=self.template, - label=self.label, - metadata_=self.metadata_, - description=self.description, - user_id=self.user_id, - ) - else: - return Block( - id=self.id, - value=self.value, - limit=self.limit, - template_name=self.template_name, - template=self.template, - label=self.label, - metadata_=self.metadata_, - description=self.description, - user_id=self.user_id, - ) - - class JobModel(Base): __tablename__ = "jobs" __table_args__ = {"extend_existing": True} @@ -425,27 +358,6 @@ class MetadataStore: session.add(AgentModel(**fields)) session.commit() - @enforce_types - def create_block(self, block: Block): - with self.session_maker() as session: - # TODO: fix? - # we are only validating that more than one template block - # with a given name doesn't exist. - if ( - session.query(BlockModel) - .filter(BlockModel.template_name == block.template_name) - .filter(BlockModel.user_id == block.user_id) - .filter(BlockModel.template == True) - .filter(BlockModel.label == block.label) - .count() - > 0 - ): - - raise ValueError(f"Block with name {block.template_name} already exists") - - session.add(BlockModel(**vars(block))) - session.commit() - @enforce_types def update_agent(self, agent: AgentState): with self.session_maker() as session: @@ -457,28 +369,6 @@ class MetadataStore: session.query(AgentModel).filter(AgentModel.id == agent.id).update(fields) session.commit() - @enforce_types - def update_block(self, block: Block): - with self.session_maker() as session: - session.query(BlockModel).filter(BlockModel.id == block.id).update(vars(block)) - session.commit() - - @enforce_types - def update_or_create_block(self, block: Block): - with self.session_maker() as session: - existing_block = session.query(BlockModel).filter(BlockModel.id == block.id).first() - if existing_block: - session.query(BlockModel).filter(BlockModel.id == block.id).update(vars(block)) - else: - session.add(BlockModel(**vars(block))) - session.commit() - - @enforce_types - def delete_block(self, block_id: str): - with self.session_maker() as session: - session.query(BlockModel).filter(BlockModel.id == block_id).delete() - session.commit() - @enforce_types def delete_agent(self, agent_id: str): with self.session_maker() as session: @@ -513,50 +403,6 @@ class MetadataStore: assert len(results) == 1, f"Expected 1 result, got {len(results)}" # should only be one result return results[0].to_record() - @enforce_types - def get_block(self, block_id: str) -> Optional[Block]: - with self.session_maker() as session: - results = session.query(BlockModel).filter(BlockModel.id == block_id).all() - if len(results) == 0: - return None - assert len(results) == 1, f"Expected 1 result, got {len(results)}" - return results[0].to_record() - - @enforce_types - def get_blocks( - self, - user_id: Optional[str], - label: Optional[str] = None, - template: Optional[bool] = None, - template_name: Optional[str] = None, - id: Optional[str] = None, - ) -> Optional[List[Block]]: - """List available blocks""" - with self.session_maker() as session: - query = session.query(BlockModel) - - if user_id: - query = query.filter(BlockModel.user_id == user_id) - - if label: - query = query.filter(BlockModel.label == label) - - if template_name: - query = query.filter(BlockModel.template_name == template_name) - - if id: - query = query.filter(BlockModel.id == id) - - if template: - query = query.filter(BlockModel.template == template) - - results = query.all() - - if len(results) == 0: - return None - - return [r.to_record() for r in results] - # agent source metadata @enforce_types def attach_source(self, user_id: str, agent_id: str, source_id: str): diff --git a/letta/o1_agent.py b/letta/o1_agent.py index b1aadec4..9539e4af 100644 --- a/letta/o1_agent.py +++ b/letta/o1_agent.py @@ -8,6 +8,7 @@ from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.tool import Tool from letta.schemas.usage import LettaUsageStatistics +from letta.schemas.user import User def send_thinking_message(self: "Agent", message: str) -> Optional[str]: @@ -43,11 +44,12 @@ class O1Agent(Agent): self, interface: AgentInterface, agent_state: AgentState, + user: User, tools: List[Tool] = [], max_thinking_steps: int = 10, first_message_verify_mono: bool = False, ): - super().__init__(interface, agent_state, tools) + super().__init__(interface, agent_state, tools, user) self.max_thinking_steps = max_thinking_steps self.tools = tools self.first_message_verify_mono = first_message_verify_mono diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index 733ce816..1b1df149 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -1,4 +1,5 @@ from letta.orm.base import Base +from letta.orm.block import Block from letta.orm.file import FileMetadata from letta.orm.organization import Organization from letta.orm.source import Source diff --git a/letta/orm/block.py b/letta/orm/block.py new file mode 100644 index 00000000..f91b4ba7 --- /dev/null +++ b/letta/orm/block.py @@ -0,0 +1,44 @@ +from typing import TYPE_CHECKING, Optional, Type + +from sqlalchemy import JSON, BigInteger, Integer +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from letta.orm.mixins import OrganizationMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.block import Block as PydanticBlock +from letta.schemas.block import Human, Persona + +if TYPE_CHECKING: + from letta.orm.organization import Organization + + +class Block(OrganizationMixin, SqlalchemyBase): + """Blocks are sections of the LLM context, representing a specific part of the total Memory""" + + __tablename__ = "block" + __pydantic_model__ = PydanticBlock + + template_name: Mapped[Optional[str]] = mapped_column( + nullable=True, doc="the unique name that identifies a block in a human-readable way" + ) + description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="a description of the block for context") + label: Mapped[str] = mapped_column(doc="the type of memory block in use, ie 'human', 'persona', 'system'") + is_template: Mapped[bool] = mapped_column( + doc="whether the block is a template (e.g. saved human/persona options as baselines for other templates)", default=False + ) + value: Mapped[str] = mapped_column(doc="Text content of the block for the respective section of core memory.") + limit: Mapped[BigInteger] = mapped_column(Integer, default=2000, doc="Character limit of the block.") + metadata_: Mapped[Optional[dict]] = mapped_column(JSON, default={}, doc="arbitrary information related to the block.") + + # relationships + organization: Mapped[Optional["Organization"]] = relationship("Organization") + + def to_pydantic(self) -> Type: + match self.label: + case "human": + Schema = Human + case "persona": + Schema = Persona + case _: + Schema = PydanticBlock + return Schema.model_validate(self) diff --git a/letta/orm/organization.py b/letta/orm/organization.py index 9cfdfb92..c4a059c5 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -23,6 +23,7 @@ class Organization(SqlalchemyBase): # relationships users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan") tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan") + blocks: Mapped[List["Block"]] = relationship("Block", back_populates="organization", cascade="all, delete-orphan") sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan") agents_tags: Mapped[List["AgentsTags"]] = relationship("AgentsTags", back_populates="organization", cascade="all, delete-orphan") files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="organization", cascade="all, delete-orphan") diff --git a/letta/schemas/block.py b/letta/schemas/block.py index 7c0ef9e7..eb516aba 100644 --- a/letta/schemas/block.py +++ b/letta/schemas/block.py @@ -14,36 +14,30 @@ class BaseBlock(LettaBase, validate_assignment=True): __id_prefix__ = "block" # data value - value: Optional[str] = Field(None, description="Value of the block.") + value: str = Field(..., description="Value of the block.") limit: int = Field(2000, description="Character limit of the block.") # template data (optional) template_name: Optional[str] = Field(None, description="Name of the block if it is a template.", alias="name") - template: bool = Field(False, description="Whether the block is a template (e.g. saved human/persona options).") + is_template: bool = Field(False, description="Whether the block is a template (e.g. saved human/persona options).") # context window label - label: str = Field(None, description="Label of the block (e.g. 'human', 'persona') in the context window.") + label: Optional[str] = Field(None, description="Label of the block (e.g. 'human', 'persona') in the context window.") # metadata description: Optional[str] = Field(None, description="Description of the block.") metadata_: Optional[dict] = Field({}, description="Metadata of the block.") - # associated user/agent - user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the block.") - @model_validator(mode="after") def verify_char_limit(self) -> Self: - try: - assert len(self) <= self.limit - except AssertionError: - error_msg = f"Edit failed: Exceeds {self.limit} character limit (requested {len(self)}) - {str(self)}." + if len(self.value) > self.limit: + error_msg = f"Edit failed: Exceeds {self.limit} character limit (requested {len(self.value)}) - {str(self)}." raise ValueError(error_msg) - except Exception as e: - raise e + return self - def __len__(self): - return len(self.value) + # def __len__(self): + # return len(self.value) def __setattr__(self, name, value): """Run validation if self.value is updated""" @@ -52,6 +46,9 @@ class BaseBlock(LettaBase, validate_assignment=True): # run validation self.__class__.model_validate(self.model_dump(exclude_unset=True)) + class Config: + extra = "ignore" # Ignores extra fields + class Block(BaseBlock): """ @@ -61,15 +58,22 @@ class Block(BaseBlock): label (str): The label of the block (e.g. 'human', 'persona'). This defines a category for the block. value (str): The value of the block. This is the string that is represented in the context window. limit (int): The character limit of the block. + is_template (bool): Whether the block is a template (e.g. saved human/persona options). Non-template blocks are not stored in the database and are ephemeral, while templated blocks are stored in the database. + label (str): The label of the block (e.g. 'human', 'persona'). This defines a category for the block. template_name (str): The name of the block template (if it is a template). - template (bool): Whether the block is a template (e.g. saved human/persona options). Non-template blocks are not stored in the database and are ephemeral, while templated blocks are stored in the database. description (str): Description of the block. metadata_ (Dict): Metadata of the block. user_id (str): The unique identifier of the user associated with the block. """ id: str = BaseBlock.generate_id_field() - value: str = Field(..., description="Value of the block.") + + # associated user/agent + organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the block.") + + # default orm fields + created_by_id: Optional[str] = Field(None, description="The id of the user that made this Block.") + last_updated_by_id: Optional[str] = Field(None, description="The id of the user that last updated this Block.") class Human(Block): @@ -84,41 +88,42 @@ class Persona(Block): label: str = "persona" -class CreateBlock(BaseBlock): +class BlockCreate(BaseBlock): """Create a block""" - template: bool = True + is_template: bool = True label: str = Field(..., description="Label of the block.") -class CreatePersona(BaseBlock): +class CreatePersona(BlockCreate): """Create a persona block""" - template: bool = True label: str = "persona" -class CreateHuman(BaseBlock): +class CreateHuman(BlockCreate): """Create a human block""" - template: bool = True label: str = "human" -class UpdateBlock(BaseBlock): +class BlockUpdate(BaseBlock): """Update a block""" - id: str = Field(..., description="The unique identifier of the block.") limit: Optional[int] = Field(2000, description="Character limit of the block.") + value: Optional[str] = Field(None, description="Value of the block.") + + class Config: + extra = "ignore" # Ignores extra fields -class UpdatePersona(UpdateBlock): +class UpdatePersona(BlockUpdate): """Update a persona block""" label: str = "persona" -class UpdateHuman(UpdateBlock): +class UpdateHuman(BlockUpdate): """Update a human block""" label: str = "human" diff --git a/letta/server/rest_api/routers/v1/blocks.py b/letta/server/rest_api/routers/v1/blocks.py index 74dc76da..6fee08dd 100644 --- a/letta/server/rest_api/routers/v1/blocks.py +++ b/letta/server/rest_api/routers/v1/blocks.py @@ -2,7 +2,8 @@ from typing import TYPE_CHECKING, List, Optional from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query -from letta.schemas.block import Block, CreateBlock, UpdateBlock +from letta.orm.errors import NoResultFound +from letta.schemas.block import Block, BlockCreate, BlockUpdate from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer @@ -22,54 +23,49 @@ def list_blocks( user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): actor = server.get_user_or_default(user_id=user_id) - - blocks = server.get_blocks(user_id=actor.id, label=label, template=templates_only, name=name) - if blocks is None: - return [] - return blocks + return server.block_manager.get_blocks(actor=actor, label=label, is_template=templates_only, template_name=name) @router.post("/", response_model=Block, operation_id="create_memory_block") def create_block( - create_block: CreateBlock = Body(...), + create_block: BlockCreate = Body(...), server: SyncServer = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): actor = server.get_user_or_default(user_id=user_id) - - create_block.user_id = actor.id - return server.create_block(user_id=actor.id, request=create_block) + block = Block(**create_block.model_dump()) + return server.block_manager.create_or_update_block(actor=actor, block=block) @router.patch("/{block_id}", response_model=Block, operation_id="update_memory_block") def update_block( block_id: str, - updated_block: UpdateBlock = Body(...), + updated_block: BlockUpdate = Body(...), server: SyncServer = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), ): - # actor = server.get_current_user() - - updated_block.id = block_id - return server.update_block(request=updated_block) + actor = server.get_user_or_default(user_id=user_id) + return server.block_manager.update_block(block_id=block_id, block_update=updated_block, actor=actor) -# TODO: delete should not return anything @router.delete("/{block_id}", response_model=Block, operation_id="delete_memory_block") def delete_block( block_id: str, server: SyncServer = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), ): - - return server.delete_block(block_id=block_id) + actor = server.get_user_or_default(user_id=user_id) + return server.block_manager.delete_block(block_id=block_id, actor=actor) @router.get("/{block_id}", response_model=Block, operation_id="get_memory_block") def get_block( block_id: str, server: SyncServer = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), ): - - block = server.get_block(block_id=block_id) - if block is None: + actor = server.get_user_or_default(user_id=user_id) + try: + return server.block_manager.get_block_by_id(block_id=block_id, actor=actor) + except NoResultFound: raise HTTPException(status_code=404, detail="Block not found") - return block diff --git a/letta/server/server.py b/letta/server/server.py index 47df6c2e..4c61cffa 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -55,13 +55,6 @@ from letta.providers import ( ) from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState from letta.schemas.api_key import APIKey, APIKeyCreate -from letta.schemas.block import ( - Block, - CreateBlock, - CreateHuman, - CreatePersona, - UpdateBlock, -) from letta.schemas.embedding_config import EmbeddingConfig # openai schemas @@ -83,6 +76,7 @@ from letta.schemas.tool import Tool, ToolCreate from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User from letta.services.agents_tags_manager import AgentsTagsManager +from letta.services.block_manager import BlockManager from letta.services.organization_manager import OrganizationManager from letta.services.source_manager import SourceManager from letta.services.tool_manager import ToolManager @@ -250,6 +244,7 @@ class SyncServer(Server): self.organization_manager = OrganizationManager() self.user_manager = UserManager() self.tool_manager = ToolManager() + self.block_manager = BlockManager() self.source_manager = SourceManager() self.agents_tags_manager = AgentsTagsManager() @@ -257,7 +252,7 @@ class SyncServer(Server): if init_with_default_org_and_user: self.default_org = self.organization_manager.create_default_organization() self.default_user = self.user_manager.create_default_user() - self.add_default_blocks(self.default_user.id) + self.block_manager.add_default_blocks(actor=self.default_user) self.tool_manager.add_base_tools(actor=self.default_user) # If there is a default org/user @@ -333,15 +328,6 @@ class SyncServer(Server): ) ) - def save_agents(self): - """Saves all the agents that are in the in-memory object store""" - for agent_d in self.active_agents: - try: - save_agent(agent_d["agent"], self.ms) - logger.debug(f"Saved agent {agent_d['agent_id']}") - except Exception as e: - logger.exception(f"Error occurred while trying to save agent {agent_d['agent_id']}:\n{e}") - def _get_agent(self, user_id: str, agent_id: str) -> Union[Agent, None]: """Get the agent object from the in-memory object store""" for d in self.active_agents: @@ -399,9 +385,9 @@ class SyncServer(Server): assert isinstance(agent_state.memory, Memory) if agent_state.agent_type == AgentType.memgpt_agent: - letta_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs) + letta_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs, user=actor) elif agent_state.agent_type == AgentType.o1_agent: - letta_agent = O1Agent(agent_state=agent_state, interface=interface, tools=tool_objs) + letta_agent = O1Agent(agent_state=agent_state, interface=interface, tools=tool_objs, user=actor) else: raise NotImplementedError("Not a supported agent type") @@ -884,6 +870,7 @@ class SyncServer(Server): first_message_verify_mono=( True if (llm_config and llm_config.model is not None and "gpt-4" in llm_config.model) else False ), + user=actor, initial_message_sequence=request.initial_message_sequence, ) elif request.agent_type == AgentType.o1_agent: @@ -895,6 +882,7 @@ class SyncServer(Server): first_message_verify_mono=( True if (llm_config and llm_config.model is not None and "gpt-4" in llm_config.model) else False ), + user=actor, ) # rebuilding agent memory on agent create in case shared memory blocks # were specified in the new agent's memory config. we're doing this for two reasons: @@ -1130,56 +1118,6 @@ class SyncServer(Server): return [self.get_agent_state(user_id=user.id, agent_id=agent_id) for agent_id in agent_ids] - def get_blocks( - self, - user_id: Optional[str] = None, - label: Optional[str] = None, - template: Optional[bool] = None, - name: Optional[str] = None, - id: Optional[str] = None, - ) -> Optional[List[Block]]: - - return self.ms.get_blocks(user_id=user_id, label=label, template=template, template_name=name, id=id) - - def get_block(self, block_id: str): - - blocks = self.get_blocks(id=block_id) - if blocks is None or len(blocks) == 0: - raise ValueError("Block does not exist") - if len(blocks) > 1: - raise ValueError("Multiple blocks with the same id") - return blocks[0] - - def create_block(self, request: CreateBlock, user_id: str, update: bool = False) -> Block: - existing_blocks = self.ms.get_blocks( - template_name=request.template_name, user_id=user_id, template=request.template, label=request.label - ) - - # for templates, update existing block template if exists - if existing_blocks is not None and request.template: - existing_block = existing_blocks[0] - assert len(existing_blocks) == 1 - if update: - return self.update_block(UpdateBlock(id=existing_block.id, **vars(request))) - else: - raise ValueError(f"Block with name {request.template_name} already exists") - block = Block(**vars(request)) - self.ms.create_block(block) - return block - - def update_block(self, request: UpdateBlock) -> Block: - block = self.get_block(request.id) - block.limit = request.limit if request.limit is not None else block.limit - block.value = request.value if request.value is not None else block.value - block.template_name = request.template_name if request.template_name is not None else block.template_name - self.ms.update_block(block=block) - return self.ms.get_block(block_id=request.id) - - def delete_block(self, block_id: str): - block = self.get_block(block_id) - self.ms.delete_block(block_id) - return block - # convert name->id def get_agent_id(self, name: str, user_id: str): @@ -1790,21 +1728,6 @@ class SyncServer(Server): return success - def add_default_blocks(self, user_id: str): - from letta.utils import list_human_files, list_persona_files - - assert user_id is not None, "User ID must be provided" - - for persona_file in list_persona_files(): - text = open(persona_file, "r", encoding="utf-8").read() - name = os.path.basename(persona_file).replace(".txt", "") - self.create_block(CreatePersona(user_id=user_id, template_name=name, value=text, template=True), user_id=user_id, update=True) - - for human_file in list_human_files(): - text = open(human_file, "r", encoding="utf-8").read() - name = os.path.basename(human_file).replace(".txt", "") - self.create_block(CreateHuman(user_id=user_id, template_name=name, value=text, template=True), user_id=user_id, update=True) - def get_agent_message(self, agent_id: str, message_id: str) -> Optional[Message]: """Get a single message from the agent's memory""" # Get the agent object (loaded in memory) diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py new file mode 100644 index 00000000..87541434 --- /dev/null +++ b/letta/services/block_manager.py @@ -0,0 +1,103 @@ +import os +from typing import List, Optional + +from letta.orm.block import Block as BlockModel +from letta.orm.errors import NoResultFound +from letta.schemas.block import Block +from letta.schemas.block import Block as PydanticBlock +from letta.schemas.block import BlockUpdate, Human, Persona +from letta.schemas.user import User as PydanticUser +from letta.utils import enforce_types, list_human_files, list_persona_files + + +class BlockManager: + """Manager class to handle business logic related to Blocks.""" + + def __init__(self): + # Fetching the db_context similarly as in ToolManager + from letta.server.server import db_context + + self.session_maker = db_context + + @enforce_types + def create_or_update_block(self, block: Block, actor: PydanticUser) -> PydanticBlock: + """Create a new block based on the Block schema.""" + db_block = self.get_block_by_id(block.id, actor) + if db_block: + update_data = BlockUpdate(**block.model_dump(exclude_none=True)) + self.update_block(block.id, update_data, actor) + else: + with self.session_maker() as session: + data = block.model_dump(exclude_none=True) + block = BlockModel(**data, organization_id=actor.organization_id) + block.create(session, actor=actor) + return block.to_pydantic() + + @enforce_types + def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock: + """Update a block by its ID with the given BlockUpdate object.""" + with self.session_maker() as session: + block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) + update_data = block_update.model_dump(exclude_unset=True, exclude_none=True) + for key, value in update_data.items(): + setattr(block, key, value) + block.update(db_session=session, actor=actor) + return block.to_pydantic() + + @enforce_types + def delete_block(self, block_id: str, actor: PydanticUser) -> PydanticBlock: + """Delete a block by its ID.""" + with self.session_maker() as session: + block = BlockModel.read(db_session=session, identifier=block_id) + block.hard_delete(db_session=session, actor=actor) + return block.to_pydantic() + + @enforce_types + def get_blocks( + self, + actor: PydanticUser, + label: Optional[str] = None, + is_template: Optional[bool] = None, + template_name: Optional[str] = None, + id: Optional[str] = None, + cursor: Optional[str] = None, + limit: Optional[int] = 50, + ) -> List[PydanticBlock]: + """Retrieve blocks based on various optional filters.""" + with self.session_maker() as session: + # Prepare filters + filters = {"organization_id": actor.organization_id} + if label: + filters["label"] = label + if is_template is not None: + filters["is_template"] = is_template + if template_name: + filters["template_name"] = template_name + if id: + filters["id"] = id + + blocks = BlockModel.list(db_session=session, cursor=cursor, limit=limit, **filters) + + return [block.to_pydantic() for block in blocks] + + @enforce_types + def get_block_by_id(self, block_id, actor: PydanticUser) -> Optional[PydanticBlock]: + """Retrieve a block by its name.""" + with self.session_maker() as session: + try: + block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) + return block.to_pydantic() + except NoResultFound: + return None + + @enforce_types + def add_default_blocks(self, actor: PydanticUser): + for persona_file in list_persona_files(): + text = open(persona_file, "r", encoding="utf-8").read() + name = os.path.basename(persona_file).replace(".txt", "") + self.create_or_update_block(Persona(template_name=name, value=text, is_template=True), actor=actor) + + for human_file in list_human_files(): + text = open(human_file, "r", encoding="utf-8").read() + name = os.path.basename(human_file).replace(".txt", "") + self.create_or_update_block(Human(template_name=name, value=text, is_template=True), actor=actor) diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index 2f32e67d..d75301a6 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -104,11 +104,7 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str) -> ChatComplet agent_state = setup_agent(client, filename) tools = [client.get_tool(client.get_tool_id(name=name)) for name in agent_state.tools] - agent = Agent( - interface=None, - tools=tools, - agent_state=agent_state, - ) + agent = Agent(interface=None, tools=tools, agent_state=agent_state, user=client.user) response = create( llm_config=agent_state.llm_config, diff --git a/tests/integration_test_summarizer.py b/tests/integration_test_summarizer.py index 6fc73b47..622ef4b6 100644 --- a/tests/integration_test_summarizer.py +++ b/tests/integration_test_summarizer.py @@ -46,7 +46,9 @@ def test_summarizer(config_filename): # Create agent agent_state = client.create_agent(name=agent_name, llm_config=llm_config, embedding_config=embedding_config) tools = [client.get_tool(client.get_tool_id(name=tool_name)) for tool_name in agent_state.tools] - letta_agent = Agent(interface=StreamingRefreshCLIInterface(), agent_state=agent_state, tools=tools, first_message_verify_mono=False) + letta_agent = Agent( + interface=StreamingRefreshCLIInterface(), agent_state=agent_state, tools=tools, first_message_verify_mono=False, user=client.user + ) # Make conversation messages = [ diff --git a/tests/test_client.py b/tests/test_client.py index f7cfaaed..7f5f095c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -588,13 +588,13 @@ def test_shared_blocks(client: Union[LocalClient, RESTClient], agent: AgentState # _reset_config() # create a block - block = client.create_block(label="human", text="username: sarah") + block = client.create_block(label="human", value="username: sarah") # create agents with shared block from letta.schemas.memory import BasicBlockMemory - persona1_block = client.create_block(label="persona", text="you are agent 1") - persona2_block = client.create_block(label="persona", text="you are agent 2") + persona1_block = client.create_block(label="persona", value="you are agent 1") + persona2_block = client.create_block(label="persona", value="you are agent 2") # create agnets agent_state1 = client.create_agent(name="agent1", memory=BasicBlockMemory(blocks=[block, persona1_block])) diff --git a/tests/test_local_client.py b/tests/test_local_client.py index 2ffd26f7..d8518a35 100644 --- a/tests/test_local_client.py +++ b/tests/test_local_client.py @@ -5,7 +5,6 @@ import pytest from letta import create_client from letta.client.client import LocalClient from letta.schemas.agent import AgentState -from letta.schemas.block import Block from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import BasicBlockMemory, ChatMemory, Memory @@ -53,7 +52,7 @@ def test_agent(client: LocalClient): agent_state = client.get_agent(agent_state_test.id) assert agent_state.name == "test_agent2" for block in agent_state.memory.to_dict()["memory"].values(): - db_block = client.server.ms.get_block(block.get("id")) + db_block = client.server.block_manager.get_block_by_id(block.get("id"), actor=client.user) assert db_block is not None, "memory block not persisted on agent create" assert db_block.value == block.get("value"), "persisted block data does not match in-memory data" @@ -169,12 +168,9 @@ def test_agent_add_remove_tools(client: LocalClient, agent): def test_agent_with_shared_blocks(client: LocalClient): - persona_block = Block(template_name="persona", value="Here to test things!", label="persona", user_id=client.user_id) - human_block = Block(template_name="human", value="Me Human, I swear. Beep boop.", label="human", user_id=client.user_id) + persona_block = client.create_block(template_name="persona", value="Here to test things!", label="persona") + human_block = client.create_block(template_name="human", value="Me Human, I swear. Beep boop.", label="human") existing_non_template_blocks = [persona_block, human_block] - for block in existing_non_template_blocks: - # ensure that previous chat blocks are persisted, as if another agent already produced them. - client.server.ms.create_block(block) existing_non_template_blocks_no_values = [] for block in existing_non_template_blocks: diff --git a/tests/test_managers.py b/tests/test_managers.py index a5d8528a..6620cff8 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -3,8 +3,10 @@ from sqlalchemy import delete import letta.utils as utils from letta.functions.functions import derive_openai_json_schema, parse_source_code -from letta.orm import FileMetadata, Organization, Source, Tool, User +from letta.orm import Block, FileMetadata, Organization, Source, Tool, User from letta.schemas.agent import CreateAgent +from letta.schemas.block import Block as PydanticBlock +from letta.schemas.block import BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.file import FileMetadata as PydanticFileMetadata from letta.schemas.llm_config import LLMConfig @@ -14,6 +16,7 @@ from letta.schemas.source import Source as PydanticSource from letta.schemas.source import SourceUpdate from letta.schemas.tool import Tool as PydanticTool from letta.schemas.tool import ToolUpdate +from letta.services.block_manager import BlockManager from letta.services.organization_manager import OrganizationManager utils.DEBUG = True @@ -38,6 +41,7 @@ DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig( def clear_tables(server: SyncServer): """Fixture to clear the organization table before each test.""" with server.organization_manager.session_maker() as session: + session.execute(delete(Block)) session.execute(delete(FileMetadata)) session.execute(delete(Source)) session.execute(delete(Tool)) # Clear all records from the Tool table @@ -433,6 +437,84 @@ def test_delete_tool_by_id(server: SyncServer, tool_fixture, default_user): assert len(tools) == 0 +# ====================================================================================================================== +# Block Manager Tests +# ====================================================================================================================== + + +def test_create_block(server: SyncServer, default_user): + block_manager = BlockManager() + block_create = PydanticBlock( + label="human", + is_template=True, + value="Sample content", + template_name="sample_template", + description="A test block", + limit=1000, + metadata_={"example": "data"}, + ) + + block = block_manager.create_or_update_block(block_create, actor=default_user) + + # Assertions to ensure the created block matches the expected values + assert block.label == block_create.label + assert block.is_template == block_create.is_template + assert block.value == block_create.value + assert block.template_name == block_create.template_name + assert block.description == block_create.description + assert block.limit == block_create.limit + assert block.metadata_ == block_create.metadata_ + assert block.organization_id == default_user.organization_id + + +def test_get_blocks(server, default_user): + block_manager = BlockManager() + + # Create blocks to retrieve later + block_manager.create_or_update_block(PydanticBlock(label="human", value="Block 1"), actor=default_user) + block_manager.create_or_update_block(PydanticBlock(label="persona", value="Block 2"), actor=default_user) + + # Retrieve blocks by different filters + all_blocks = block_manager.get_blocks(actor=default_user) + assert len(all_blocks) == 2 + + human_blocks = block_manager.get_blocks(actor=default_user, label="human") + assert len(human_blocks) == 1 + assert human_blocks[0].label == "human" + + persona_blocks = block_manager.get_blocks(actor=default_user, label="persona") + assert len(persona_blocks) == 1 + assert persona_blocks[0].label == "persona" + + +def test_update_block(server: SyncServer, default_user): + block_manager = BlockManager() + block = block_manager.create_or_update_block(PydanticBlock(label="persona", value="Original Content"), actor=default_user) + + # Update block's content + update_data = BlockUpdate(value="Updated Content", description="Updated description") + block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user) + + # Retrieve the updated block + updated_block = block_manager.get_blocks(actor=default_user, id=block.id)[0] + + # Assertions to verify the update + assert updated_block.value == "Updated Content" + assert updated_block.description == "Updated description" + + +def test_delete_block(server: SyncServer, default_user): + block_manager = BlockManager() + + # Create and delete a block + block = block_manager.create_or_update_block(PydanticBlock(label="human", value="Sample content"), actor=default_user) + block_manager.delete_block(block_id=block.id, actor=default_user) + + # Verify that the block was deleted + blocks = block_manager.get_blocks(actor=default_user) + assert len(blocks) == 0 + + # ====================================================================================================================== # Source Manager Tests - Sources # ====================================================================================================================== diff --git a/tests/test_memory.py b/tests/test_memory.py index 91a29f3f..3760f31a 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -33,29 +33,6 @@ def test_load_memory_from_json(sample_memory: Memory): assert new_memory.get_block("human").value == "User" -# def test_memory_functionality(sample_memory): -# """Test memory modification functions""" -# # Get memory functions -# functions = get_memory_functions(ChatMemory) -# # Test core_memory_append function -# append_func = functions['core_memory_append'] -# print("FUNCTIONS", functions) -# env = {} -# env.update(globals()) -# for tool in functions: -# # WARNING: name may not be consistent? -# exec(tool.source_code, env) -# -# print(exec) -# -# append_func(sample_memory, 'persona', " is a test.") -# assert sample_memory.memory['persona'].value == "Chat Agent\n is a test." -# # Test core_memory_replace function -# replace_func = functions['core_memory_replace'] -# replace_func(sample_memory, 'persona', " is a test.", " was a test.") -# assert sample_memory.memory['persona'].value == "Chat Agent\n was a test." - - def test_memory_limit_validation(sample_memory: Memory): """Test exceeding memory limit""" with pytest.raises(ValueError): @@ -89,7 +66,7 @@ def test_memory_jinja2_template(sample_memory: Memory): """Generate a string representation of the memory in-context""" section_strs = [] for section, module in self.memory.items(): - section_strs.append(f'<{section} characters="{len(module)}/{module.limit}">\n{module.value}\n') + section_strs.append(f'<{section} characters="{len(module.value)}/{module.limit}">\n{module.value}\n') return "\n".join(section_strs) old_repr_str = old_repr(sample_memory) diff --git a/tests/test_tools.py b/tests/test_tools.py index 124520ec..3515eca4 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,6 +1,3 @@ -import os -import threading -import time import uuid from typing import Union @@ -37,33 +34,12 @@ def run_server(): start_server(debug=True) -# Fixture to create clients with different configurations -@pytest.fixture( - # params=[{"server": True}, {"server": False}], # whether to use REST API server - params=[{"server": True}], # whether to use REST API server - scope="module", -) -def client(request): - - if request.param["server"]: - # get URL from enviornment - server_url = os.getenv("MEMGPT_SERVER_URL") - if server_url is None: - # run server in thread - server_url = "http://localhost:8283" - print("Starting server thread") - thread = threading.Thread(target=run_server, daemon=True) - thread.start() - time.sleep(5) - print("Running client tests with server:", server_url) - else: - assert False, "Local client not implemented" - - assert server_url is not None - client = create_client(base_url=server_url) # This yields control back to the test function +@pytest.fixture(scope="module") +def client(): + client = create_client() client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) - # Clear all records from the Tool table + yield client @@ -153,8 +129,8 @@ def test_create_agent_tool(client): human = initial_memory.get_block("human") persona = initial_memory.get_block("persona") print("Initial memory:", human, persona) - assert len(human) > 0, "Expected human memory to be non-empty" - assert len(persona) > 0, "Expected persona memory to be non-empty" + assert len(human.value) > 0, "Expected human memory to be non-empty" + assert len(persona.value) > 0, "Expected persona memory to be non-empty" # test agent tool response = client.send_message(role="user", agent_id=agent.id, message="clear your memory with the core_memory_clear tool") @@ -166,8 +142,8 @@ def test_create_agent_tool(client): human = updated_memory.get_block("human") persona = updated_memory.get_block("persona") print("Updated memory:", human, persona) - assert len(human) == 0, "Expected human memory to be empty" - assert len(persona) == 0, "Expected persona memory to be empty" + assert len(human.value) == 0, "Expected human memory to be empty" + assert len(persona.value) == 0, "Expected persona memory to be empty" def test_custom_import_tool(client):