From 98c5702ef9ca17d6cfbf0e01b407c8082eb1d7fb Mon Sep 17 00:00:00 2001 From: cthomas Date: Wed, 22 Jan 2025 19:05:41 -0800 Subject: [PATCH] chore: rename metadata_ field to metadata in pydantic (#732) --- .../notebooks/Agentic RAG with Letta.ipynb | 2 +- letta/agent.py | 2 +- letta/chat_only_agent.py | 2 +- letta/client/client.py | 10 +-- letta/data_sources/connectors.py | 2 +- letta/orm/agent.py | 4 +- letta/orm/block.py | 4 +- letta/orm/sqlalchemy_base.py | 5 ++ letta/schemas/agent.py | 6 +- letta/schemas/block.py | 4 +- letta/schemas/job.py | 2 +- letta/schemas/letta_base.py | 6 ++ letta/schemas/passage.py | 2 +- letta/schemas/source.py | 8 +- letta/server/rest_api/routers/v1/agents.py | 6 +- letta/server/rest_api/routers/v1/jobs.py | 4 +- letta/server/rest_api/routers/v1/sources.py | 2 +- letta/server/server.py | 6 +- letta/services/agent_manager.py | 15 ++-- letta/services/block_manager.py | 6 +- letta/services/job_manager.py | 8 +- letta/services/organization_manager.py | 2 +- letta/services/passage_manager.py | 6 +- letta/services/provider_manager.py | 4 +- letta/services/sandbox_config_manager.py | 4 +- letta/services/source_manager.py | 6 +- letta/services/tool_execution_sandbox.py | 4 +- letta/services/tool_manager.py | 6 +- letta/services/user_manager.py | 4 +- tests/helpers/client_helper.py | 4 +- tests/helpers/utils.py | 2 +- tests/test_client_legacy.py | 10 +-- tests/test_managers.py | 73 ++++++++++--------- tests/test_server.py | 13 ++-- 34 files changed, 133 insertions(+), 111 deletions(-) diff --git a/examples/notebooks/Agentic RAG with Letta.ipynb b/examples/notebooks/Agentic RAG with Letta.ipynb index 47df76bd..45ff8973 100644 --- a/examples/notebooks/Agentic RAG with Letta.ipynb +++ b/examples/notebooks/Agentic RAG with Letta.ipynb @@ -101,7 +101,7 @@ } ], "source": [ - "client.jobs.get(job_id=job.id).metadata_" + "client.jobs.get(job_id=job.id).metadata" ] }, { diff --git a/letta/agent.py b/letta/agent.py index cf2cca11..ebcc118c 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -1091,7 +1091,7 @@ def save_agent(agent: Agent): embedding_config=agent_state.embedding_config, message_ids=agent_state.message_ids, description=agent_state.description, - metadata_=agent_state.metadata_, + metadata=agent_state.metadata, # TODO: Add this back in later # tool_exec_environment_variables=agent_state.get_agent_env_vars_as_dict(), ) diff --git a/letta/chat_only_agent.py b/letta/chat_only_agent.py index e5f431c5..687763b0 100644 --- a/letta/chat_only_agent.py +++ b/letta/chat_only_agent.py @@ -87,7 +87,7 @@ class ChatOnlyAgent(Agent): memory=offline_memory, llm_config=LLMConfig.default_config("gpt-4"), embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"), - tool_ids=self.agent_state.metadata_.get("offline_memory_tools", []), + tool_ids=self.agent_state.metadata.get("offline_memory_tools", []), include_base_tools=False, ) self.offline_memory_agent.memory.update_block_value(label="conversation_block", value=recent_convo) diff --git a/letta/client/client.py b/letta/client/client.py index 1cacbd44..8c9f5cf6 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -586,7 +586,7 @@ class RESTClient(AbstractClient): # create agent create_params = { "description": description, - "metadata_": metadata, + "metadata": metadata, "memory_blocks": [], "block_ids": [b.id for b in memory.get_blocks()] + block_ids, "tool_ids": tool_ids, @@ -685,7 +685,7 @@ class RESTClient(AbstractClient): tool_ids=tool_ids, tags=tags, description=description, - metadata_=metadata, + metadata=metadata, llm_config=llm_config, embedding_config=embedding_config, message_ids=message_ids, @@ -2331,7 +2331,7 @@ class LocalClient(AbstractClient): # Create the base parameters create_params = { "description": description, - "metadata_": metadata, + "metadata": metadata, "memory_blocks": [], "block_ids": [b.id for b in memory.get_blocks()] + block_ids, "tool_ids": tool_ids, @@ -2423,7 +2423,7 @@ class LocalClient(AbstractClient): tool_ids=tool_ids, tags=tags, description=description, - metadata_=metadata, + metadata=metadata, llm_config=llm_config, embedding_config=embedding_config, message_ids=message_ids, @@ -3100,7 +3100,7 @@ class LocalClient(AbstractClient): job = Job( user_id=self.user_id, status=JobStatus.created, - metadata_={"type": "embedding", "filename": filename, "source_id": source_id}, + metadata={"type": "embedding", "filename": filename, "source_id": source_id}, ) job = self.server.job_manager.create_job(pydantic_job=job, actor=self.user) diff --git a/letta/data_sources/connectors.py b/letta/data_sources/connectors.py index 8ae67f88..12408ca6 100644 --- a/letta/data_sources/connectors.py +++ b/letta/data_sources/connectors.py @@ -77,7 +77,7 @@ def load_data(connector: DataConnector, source: Source, passage_manager: Passage text=passage_text, file_id=file_metadata.id, source_id=source.id, - metadata_=passage_metadata, + metadata=passage_metadata, organization_id=source.organization_id, embedding_config=source.embedding_config, embedding=embedding, diff --git a/letta/orm/agent.py b/letta/orm/agent.py index 23b8d74c..1ea92277 100644 --- a/letta/orm/agent.py +++ b/letta/orm/agent.py @@ -113,14 +113,14 @@ class Agent(SqlalchemyBase, OrganizationMixin): "description": self.description, "message_ids": self.message_ids, "tools": self.tools, - "sources": self.sources, + "sources": [source.to_pydantic() for source in self.sources], "tags": [t.tag for t in self.tags], "tool_rules": self.tool_rules, "system": self.system, "agent_type": self.agent_type, "llm_config": self.llm_config, "embedding_config": self.embedding_config, - "metadata_": self.metadata_, + "metadata": self.metadata_, "memory": Memory(blocks=[b.to_pydantic() for b in self.core_memory]), "created_by_id": self.created_by_id, "last_updated_by_id": self.last_updated_by_id, diff --git a/letta/orm/block.py b/letta/orm/block.py index 99cfa29b..7395b0af 100644 --- a/letta/orm/block.py +++ b/letta/orm/block.py @@ -45,7 +45,9 @@ class Block(OrganizationMixin, SqlalchemyBase): Schema = Persona case _: Schema = PydanticBlock - return Schema.model_validate(self) + model_dict = {k: v for k, v in self.__dict__.items() if k in self.__pydantic_model__.model_fields} + model_dict["metadata"] = self.metadata_ + return Schema.model_validate(model_dict) @event.listens_for(Block, "after_update") # Changed from 'before_update' diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 05ada679..fd0c1e3a 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -449,6 +449,11 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): def to_pydantic(self) -> "BaseModel": """converts to the basic pydantic model counterpart""" + if hasattr(self, "metadata_"): + model_dict = {k: v for k, v in self.__dict__.items() if k in self.__pydantic_model__.model_fields} + model_dict["metadata"] = self.metadata_ + return self.__pydantic_model__.model_validate(model_dict) + return self.__pydantic_model__.model_validate(self) def to_record(self) -> "BaseModel": diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index edcede23..3214307a 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -72,7 +72,7 @@ class AgentState(OrmMetadataBase, validate_assignment=True): organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the agent.") description: Optional[str] = Field(None, description="The description of the agent.") - metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_") + metadata: Optional[Dict] = Field(None, description="The metadata of the agent.") memory: Memory = Field(..., description="The in-context memory of the agent.") tools: List[Tool] = Field(..., description="The tools used by the agent.") @@ -122,7 +122,7 @@ class CreateAgent(BaseModel, validate_assignment=True): # False, description="If true, attaches the Letta multi-agent tools (e.g. sending a message to another agent)." ) description: Optional[str] = Field(None, description="The description of the agent.") - metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_") + metadata: Optional[Dict] = Field(None, description="The metadata of the agent.") model: Optional[str] = Field( None, description="The LLM configuration handle used by the agent, specified in the format " @@ -203,7 +203,7 @@ class UpdateAgent(BaseModel): embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.") message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.") description: Optional[str] = Field(None, description="The description of the agent.") - metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_") + metadata: Optional[Dict] = Field(None, description="The metadata of the agent.") tool_exec_environment_variables: Optional[Dict[str, str]] = Field( None, description="The environment variables for tool execution specific to this agent." ) diff --git a/letta/schemas/block.py b/letta/schemas/block.py index 25e84b7d..2aa518cb 100644 --- a/letta/schemas/block.py +++ b/letta/schemas/block.py @@ -27,7 +27,7 @@ class BaseBlock(LettaBase, validate_assignment=True): # metadata description: Optional[str] = Field(None, description="Description of the block.") - metadata_: Optional[dict] = Field({}, description="Metadata of the block.") + metadata: Optional[dict] = Field({}, description="Metadata of the block.") # def __len__(self): # return len(self.value) @@ -63,7 +63,7 @@ class Block(BaseBlock): label (str): The label of the block (e.g. 'human', 'persona'). This defines a category for the block. template_name (str): The name of the block template (if it is a template). description (str): Description of the block. - metadata_ (Dict): Metadata of the block. + metadata (Dict): Metadata of the block. user_id (str): The unique identifier of the user associated with the block. """ diff --git a/letta/schemas/job.py b/letta/schemas/job.py index c61c5839..0ffbf2c7 100644 --- a/letta/schemas/job.py +++ b/letta/schemas/job.py @@ -12,7 +12,7 @@ class JobBase(OrmMetadataBase): __id_prefix__ = "job" status: JobStatus = Field(default=JobStatus.created, description="The status of the job.") completed_at: Optional[datetime] = Field(None, description="The unix timestamp of when the job was completed.") - metadata_: Optional[dict] = Field(None, description="The metadata of the job.") + metadata: Optional[dict] = Field(None, description="The metadata of the job.") job_type: JobType = Field(default=JobType.JOB, description="The type of the job.") diff --git a/letta/schemas/letta_base.py b/letta/schemas/letta_base.py index bb29a5be..d6850933 100644 --- a/letta/schemas/letta_base.py +++ b/letta/schemas/letta_base.py @@ -88,6 +88,12 @@ class LettaBase(BaseModel): return f"{cls.__id_prefix__}-{v}" return v + def model_dump(self, to_orm: bool = False, **kwargs): + data = super().model_dump(**kwargs) + if to_orm and "metadata" in data: + data["metadata_"] = data.pop("metadata") + return data + class OrmMetadataBase(LettaBase): # metadata fields diff --git a/letta/schemas/passage.py b/letta/schemas/passage.py index c1ec13be..648364c2 100644 --- a/letta/schemas/passage.py +++ b/letta/schemas/passage.py @@ -23,7 +23,7 @@ class PassageBase(OrmMetadataBase): # file association file_id: Optional[str] = Field(None, description="The unique identifier of the file associated with the passage.") - metadata_: Optional[Dict] = Field({}, description="The metadata of the passage.") + metadata: Optional[Dict] = Field({}, description="The metadata of the passage.") class Passage(PassageBase): diff --git a/letta/schemas/source.py b/letta/schemas/source.py index 0a458dfd..796f50eb 100644 --- a/letta/schemas/source.py +++ b/letta/schemas/source.py @@ -24,7 +24,7 @@ class Source(BaseSource): name (str): The name of the source. embedding_config (EmbeddingConfig): The embedding configuration used by the source. user_id (str): The ID of the user that created the source. - metadata_ (dict): Metadata associated with the source. + metadata (dict): Metadata associated with the source. description (str): The description of the source. """ @@ -33,7 +33,7 @@ class Source(BaseSource): description: Optional[str] = Field(None, description="The description of the source.") embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the source.") organization_id: Optional[str] = Field(None, description="The ID of the organization that created the source.") - metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.") + metadata: Optional[dict] = Field(None, description="Metadata associated with the source.") # metadata fields created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.") @@ -54,7 +54,7 @@ class SourceCreate(BaseSource): # optional description: Optional[str] = Field(None, description="The description of the source.") - metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.") + metadata: Optional[dict] = Field(None, description="Metadata associated with the source.") class SourceUpdate(BaseSource): @@ -64,5 +64,5 @@ class SourceUpdate(BaseSource): name: Optional[str] = Field(None, description="The name of the source.") description: Optional[str] = Field(None, description="The description of the source.") - metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.") + metadata: Optional[dict] = Field(None, description="Metadata associated with the source.") embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the source.") diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 988d9ede..17054666 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -556,7 +556,7 @@ async def process_message_background( job_update = JobUpdate( status=JobStatus.completed, completed_at=datetime.utcnow(), - metadata_={"result": result.model_dump()}, # Store the result in metadata + metadata={"result": result.model_dump()}, # Store the result in metadata ) server.job_manager.update_job_by_id(job_id=job_id, job_update=job_update, actor=actor) @@ -568,7 +568,7 @@ async def process_message_background( job_update = JobUpdate( status=JobStatus.failed, completed_at=datetime.utcnow(), - metadata_={"error": str(e)}, + metadata={"error": str(e)}, ) server.job_manager.update_job_by_id(job_id=job_id, job_update=job_update, actor=actor) raise @@ -596,7 +596,7 @@ async def send_message_async( run = Run( user_id=actor.id, status=JobStatus.created, - metadata_={ + metadata={ "job_type": "send_message_async", "agent_id": agent_id, }, diff --git a/letta/server/rest_api/routers/v1/jobs.py b/letta/server/rest_api/routers/v1/jobs.py index 9e6cc78f..4e41490b 100644 --- a/letta/server/rest_api/routers/v1/jobs.py +++ b/letta/server/rest_api/routers/v1/jobs.py @@ -26,9 +26,9 @@ def list_jobs( jobs = server.job_manager.list_jobs(actor=actor) if source_id: - # can't be in the ORM since we have source_id stored in the metadata_ + # can't be in the ORM since we have source_id stored in the metadata # TODO: Probably change this - jobs = [job for job in jobs if job.metadata_.get("source_id") == source_id] + jobs = [job for job in jobs if job.metadata.get("source_id") == source_id] return jobs diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index b4c0e841..6ce045bb 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -131,7 +131,7 @@ def upload_file_to_source( # create job job = Job( user_id=actor.id, - metadata_={"type": "embedding", "filename": file.filename, "source_id": source_id}, + metadata={"type": "embedding", "filename": file.filename, "source_id": source_id}, completed_at=None, ) job_id = job.id diff --git a/letta/server/server.py b/letta/server/server.py index 04841419..69602eb2 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -956,8 +956,8 @@ class SyncServer(Server): # update job status job.status = JobStatus.completed - job.metadata_["num_passages"] = num_passages - job.metadata_["num_documents"] = num_documents + job.metadata["num_passages"] = num_passages + job.metadata["num_documents"] = num_documents self.job_manager.update_job_by_id(job_id=job_id, job_update=JobUpdate(**job.model_dump()), actor=actor) # update all agents who have this source attached @@ -1019,7 +1019,7 @@ class SyncServer(Server): attached_agents = [{"id": agent.id, "name": agent.name} for agent in agents] # Overwrite metadata field, should be empty anyways - source.metadata_ = dict( + source.metadata = dict( num_documents=num_documents, num_passages=num_passages, attached_agents=attached_agents, diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 41e29b78..a84a5df9 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -82,7 +82,7 @@ class AgentManager: block_ids = list(agent_create.block_ids or []) # Create a local copy to avoid modifying the original if agent_create.memory_blocks: for create_block in agent_create.memory_blocks: - block = self.block_manager.create_or_update_block(PydanticBlock(**create_block.model_dump()), actor=actor) + block = self.block_manager.create_or_update_block(PydanticBlock(**create_block.model_dump(to_orm=True)), actor=actor) block_ids.append(block.id) # TODO: Remove this block once we deprecate the legacy `tools` field @@ -117,7 +117,7 @@ class AgentManager: source_ids=agent_create.source_ids or [], tags=agent_create.tags or [], description=agent_create.description, - metadata_=agent_create.metadata_, + metadata=agent_create.metadata, tool_rules=agent_create.tool_rules, actor=actor, ) @@ -177,7 +177,7 @@ class AgentManager: source_ids: List[str], tags: List[str], description: Optional[str] = None, - metadata_: Optional[Dict] = None, + metadata: Optional[Dict] = None, tool_rules: Optional[List[PydanticToolRule]] = None, ) -> PydanticAgentState: """Create a new agent.""" @@ -191,7 +191,7 @@ class AgentManager: "embedding_config": embedding_config, "organization_id": actor.organization_id, "description": description, - "metadata_": metadata_, + "metadata_": metadata, "tool_rules": tool_rules, } @@ -242,11 +242,14 @@ class AgentManager: agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) # Update scalar fields directly - scalar_fields = {"name", "system", "llm_config", "embedding_config", "message_ids", "tool_rules", "description", "metadata_"} + scalar_fields = {"name", "system", "llm_config", "embedding_config", "message_ids", "tool_rules", "description", "metadata"} for field in scalar_fields: value = getattr(agent_update, field, None) if value is not None: - setattr(agent, field, value) + if field == "metadata": + setattr(agent, "metadata_", value) + else: + setattr(agent, field, value) # Update relationships using _process_relationship and _process_tags if agent_update.tool_ids is not None: diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index 77eb5e7e..ad46b3bb 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -24,11 +24,11 @@ class BlockManager: """Create a new block based on the Block schema.""" db_block = self.get_block_by_id(block.id, actor) if db_block: - update_data = BlockUpdate(**block.model_dump(exclude_none=True)) + update_data = BlockUpdate(**block.model_dump(to_orm=True, exclude_none=True)) self.update_block(block.id, update_data, actor) else: with self.session_maker() as session: - data = block.model_dump(exclude_none=True) + data = block.model_dump(to_orm=True, exclude_none=True) block = BlockModel(**data, organization_id=actor.organization_id) block.create(session, actor=actor) return block.to_pydantic() @@ -40,7 +40,7 @@ class BlockManager: with self.session_maker() as session: block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) - update_data = block_update.model_dump(exclude_unset=True, exclude_none=True) + update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) for key, value in update_data.items(): setattr(block, key, value) diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index f014c568..7e29e78d 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -39,7 +39,7 @@ class JobManager: with self.session_maker() as session: # Associate the job with the user pydantic_job.user_id = actor.id - job_data = pydantic_job.model_dump() + job_data = pydantic_job.model_dump(to_orm=True) job = JobModel(**job_data) job.create(session, actor=actor) # Save job in the database return job.to_pydantic() @@ -52,7 +52,7 @@ class JobManager: job = self._verify_job_access(session=session, job_id=job_id, actor=actor, access=["write"]) # Update job attributes with only the fields that were explicitly set - update_data = job_update.model_dump(exclude_unset=True, exclude_none=True) + update_data = job_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) # Automatically update the completion timestamp if status is set to 'completed' if update_data.get("status") == JobStatus.completed and not job.completed_at: @@ -62,7 +62,9 @@ class JobManager: setattr(job, key, value) # Save the updated job to the database - return job.update(db_session=session) # TODO: Add this later , actor=actor) + job.update(db_session=session) # TODO: Add this later , actor=actor) + + return job.to_pydantic() @enforce_types def get_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob: diff --git a/letta/services/organization_manager.py b/letta/services/organization_manager.py index fc86b05f..4f1b2f9f 100644 --- a/letta/services/organization_manager.py +++ b/letta/services/organization_manager.py @@ -44,7 +44,7 @@ class OrganizationManager: @enforce_types def _create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization: with self.session_maker() as session: - org = OrganizationModel(**pydantic_org.model_dump()) + org = OrganizationModel(**pydantic_org.model_dump(to_orm=True)) org.create(session) return org.to_pydantic() diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index f80e0160..5a2f0f5d 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -38,14 +38,14 @@ class PassageManager: def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage: """Create a new passage in the appropriate table based on whether it has agent_id or source_id.""" # Common fields for both passage types - data = pydantic_passage.model_dump() + data = pydantic_passage.model_dump(to_orm=True) common_fields = { "id": data.get("id"), "text": data["text"], "embedding": data["embedding"], "embedding_config": data["embedding_config"], "organization_id": data["organization_id"], - "metadata_": data.get("metadata_", {}), + "metadata_": data.get("metadata", {}), "is_deleted": data.get("is_deleted", False), "created_at": data.get("created_at", datetime.utcnow()), } @@ -145,7 +145,7 @@ class PassageManager: raise ValueError(f"Passage with id {passage_id} does not exist.") # Update the database record with values from the provided record - update_data = passage.model_dump(exclude_unset=True, exclude_none=True) + update_data = passage.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) for key, value in update_data.items(): setattr(curr_passage, key, value) diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 989e7eb7..1e32d588 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -24,7 +24,7 @@ class ProviderManager: # Lazily create the provider id prior to persistence provider.resolve_identifier() - new_provider = ProviderModel(**provider.model_dump(exclude_unset=True)) + new_provider = ProviderModel(**provider.model_dump(to_orm=True, exclude_unset=True)) new_provider.create(session) return new_provider.to_pydantic() @@ -36,7 +36,7 @@ class ProviderManager: existing_provider = ProviderModel.read(db_session=session, identifier=provider_update.id) # Update only the fields that are provided in ProviderUpdate - update_data = provider_update.model_dump(exclude_unset=True, exclude_none=True) + update_data = provider_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) for key, value in update_data.items(): setattr(existing_provider, key, value) diff --git a/letta/services/sandbox_config_manager.py b/letta/services/sandbox_config_manager.py index 63313c87..8411078c 100644 --- a/letta/services/sandbox_config_manager.py +++ b/letta/services/sandbox_config_manager.py @@ -172,7 +172,7 @@ class SandboxConfigManager: return db_env_var else: with self.session_maker() as session: - env_var = SandboxEnvVarModel(**env_var.model_dump(exclude_none=True)) + env_var = SandboxEnvVarModel(**env_var.model_dump(to_orm=True, exclude_none=True)) env_var.create(session, actor=actor) return env_var.to_pydantic() @@ -183,7 +183,7 @@ class SandboxConfigManager: """Update an existing sandbox environment variable.""" with self.session_maker() as session: env_var = SandboxEnvVarModel.read(db_session=session, identifier=env_var_id, actor=actor) - update_data = env_var_update.model_dump(exclude_unset=True, exclude_none=True) + update_data = env_var_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) update_data = {key: value for key, value in update_data.items() if getattr(env_var, key) != value} if update_data: diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index a5804347..fd15a0a1 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -30,7 +30,7 @@ class SourceManager: with self.session_maker() as session: # Provide default embedding config if not given source.organization_id = actor.organization_id - source = SourceModel(**source.model_dump(exclude_none=True)) + source = SourceModel(**source.model_dump(to_orm=True, exclude_none=True)) source.create(session, actor=actor) return source.to_pydantic() @@ -41,7 +41,7 @@ class SourceManager: source = SourceModel.read(db_session=session, identifier=source_id, actor=actor) # get update dictionary - update_data = source_update.model_dump(exclude_unset=True, exclude_none=True) + update_data = source_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) # Remove redundant update fields update_data = {key: value for key, value in update_data.items() if getattr(source, key) != value} @@ -132,7 +132,7 @@ class SourceManager: else: with self.session_maker() as session: file_metadata.organization_id = actor.organization_id - file_metadata = FileMetadataModel(**file_metadata.model_dump(exclude_none=True)) + file_metadata = FileMetadataModel(**file_metadata.model_dump(to_orm=True, exclude_none=True)) file_metadata.create(session, actor=actor) return file_metadata.to_pydantic() diff --git a/letta/services/tool_execution_sandbox.py b/letta/services/tool_execution_sandbox.py index 1d7b0d73..5016ce1b 100644 --- a/letta/services/tool_execution_sandbox.py +++ b/letta/services/tool_execution_sandbox.py @@ -364,7 +364,9 @@ class ToolExecutionSandbox: sbx = Sandbox(sandbox_config.get_e2b_config().template, metadata={self.METADATA_CONFIG_STATE_KEY: state_hash}) else: # no template - sbx = Sandbox(metadata={self.METADATA_CONFIG_STATE_KEY: state_hash}, **e2b_config.model_dump(exclude={"pip_requirements"})) + sbx = Sandbox( + metadata={self.METADATA_CONFIG_STATE_KEY: state_hash}, **e2b_config.model_dump(to_orm=True, exclude={"pip_requirements"}) + ) # install pip requirements if e2b_config.pip_requirements: diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 3a66aaa3..9facb153 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -39,7 +39,7 @@ class ToolManager: tool = self.get_tool_by_name(tool_name=pydantic_tool.name, actor=actor) if tool: # Put to dict and remove fields that should not be reset - update_data = pydantic_tool.model_dump(exclude_unset=True, exclude_none=True) + update_data = pydantic_tool.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) # If there's anything to update if update_data: @@ -67,7 +67,7 @@ class ToolManager: # Auto-generate description if not provided if pydantic_tool.description is None: pydantic_tool.description = pydantic_tool.json_schema.get("description", None) - tool_data = pydantic_tool.model_dump() + tool_data = pydantic_tool.model_dump(to_orm=True) tool = ToolModel(**tool_data) tool.create(session, actor=actor) # Re-raise other database-related errors @@ -112,7 +112,7 @@ class ToolManager: tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor) # Update tool attributes with only the fields that were explicitly set - update_data = tool_update.model_dump(exclude_none=True) + update_data = tool_update.model_dump(to_orm=True, exclude_none=True) for key, value in update_data.items(): setattr(tool, key, value) diff --git a/letta/services/user_manager.py b/letta/services/user_manager.py index 5dca0fff..061f443e 100644 --- a/letta/services/user_manager.py +++ b/letta/services/user_manager.py @@ -45,7 +45,7 @@ class UserManager: def create_user(self, pydantic_user: PydanticUser) -> PydanticUser: """Create a new user if it doesn't already exist.""" with self.session_maker() as session: - new_user = UserModel(**pydantic_user.model_dump()) + new_user = UserModel(**pydantic_user.model_dump(to_orm=True)) new_user.create(session) return new_user.to_pydantic() @@ -57,7 +57,7 @@ class UserManager: existing_user = UserModel.read(db_session=session, identifier=user_update.id) # Update only the fields that are provided in UserUpdate - update_data = user_update.model_dump(exclude_unset=True, exclude_none=True) + update_data = user_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) for key, value in update_data.items(): setattr(existing_user, key, value) diff --git a/tests/helpers/client_helper.py b/tests/helpers/client_helper.py index e7cce8ef..815102a8 100644 --- a/tests/helpers/client_helper.py +++ b/tests/helpers/client_helper.py @@ -10,14 +10,14 @@ from letta.schemas.source import Source def upload_file_using_client(client: Union[LocalClient, RESTClient], source: Source, filename: str) -> Job: # load a file into a source (non-blocking job) upload_job = client.load_file_to_source(filename=filename, source_id=source.id, blocking=False) - print("Upload job", upload_job, upload_job.status, upload_job.metadata_) + print("Upload job", upload_job, upload_job.status, upload_job.metadata) # view active jobs active_jobs = client.list_active_jobs() jobs = client.list_jobs() assert upload_job.id in [j.id for j in jobs] assert len(active_jobs) == 1 - assert active_jobs[0].metadata_["source_id"] == source.id + assert active_jobs[0].metadata["source_id"] == source.id # wait for job to finish (with timeout) timeout = 240 diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 765c4612..167a39ee 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -96,7 +96,7 @@ def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, Up # Assert scalar fields assert agent.system == request.system, f"System prompt mismatch: {agent.system} != {request.system}" assert agent.description == request.description, f"Description mismatch: {agent.description} != {request.description}" - assert agent.metadata_ == request.metadata_, f"Metadata mismatch: {agent.metadata_} != {request.metadata_}" + assert agent.metadata == request.metadata, f"Metadata mismatch: {agent.metadata} != {request.metadata}" # Assert agent env vars if hasattr(request, "tool_exec_environment_variables"): diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index 882c4e7e..e62b1833 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -466,8 +466,8 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState): assert len(sources) == 1 # TODO: add back? - assert sources[0].metadata_["num_passages"] == 0 - assert sources[0].metadata_["num_documents"] == 0 + assert sources[0].metadata["num_passages"] == 0 + assert sources[0].metadata["num_documents"] == 0 # update the source original_id = source.id @@ -491,7 +491,7 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState): filename = "tests/data/memgpt_paper.pdf" upload_job = upload_file_using_client(client, source, filename) job = client.get_job(upload_job.id) - created_passages = job.metadata_["num_passages"] + created_passages = job.metadata["num_passages"] # TODO: add test for blocking job @@ -515,8 +515,8 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState): # check number of passages sources = client.list_sources() # TODO: add back? - # assert sources.sources[0].metadata_["num_passages"] > 0 - # assert sources.sources[0].metadata_["num_documents"] == 0 # TODO: fix this once document store added + # assert sources.sources[0].metadata["num_passages"] > 0 + # assert sources.sources[0].metadata["num_documents"] == 0 # TODO: fix this once document store added print(sources) # detach the source diff --git a/tests/test_managers.py b/tests/test_managers.py index ad825db5..f4ee1306 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -134,7 +134,7 @@ def default_source(server: SyncServer, default_user): source_pydantic = PydanticSource( name="Test Source", description="This is a test source.", - metadata_={"type": "test"}, + metadata={"type": "test"}, embedding_config=DEFAULT_EMBEDDING_CONFIG, ) source = server.source_manager.create_source(source=source_pydantic, actor=default_user) @@ -146,7 +146,7 @@ def other_source(server: SyncServer, default_user): source_pydantic = PydanticSource( name="Another Test Source", description="This is yet another test source.", - metadata_={"type": "another_test"}, + metadata={"type": "another_test"}, embedding_config=DEFAULT_EMBEDDING_CONFIG, ) source = server.source_manager.create_source(source=source_pydantic, actor=default_user) @@ -235,7 +235,7 @@ def agent_passage_fixture(server: SyncServer, default_user, sarah_agent): organization_id=default_user.organization_id, embedding=[0.1], embedding_config=DEFAULT_EMBEDDING_CONFIG, - metadata_={"type": "test"}, + metadata={"type": "test"}, ), actor=default_user, ) @@ -253,7 +253,7 @@ def source_passage_fixture(server: SyncServer, default_user, default_file, defau organization_id=default_user.organization_id, embedding=[0.1], embedding_config=DEFAULT_EMBEDDING_CONFIG, - metadata_={"type": "test"}, + metadata={"type": "test"}, ), actor=default_user, ) @@ -273,7 +273,7 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a organization_id=default_user.organization_id, embedding=[0.1], embedding_config=DEFAULT_EMBEDDING_CONFIG, - metadata_={"type": "test"}, + metadata={"type": "test"}, ), actor=default_user, ) @@ -291,7 +291,7 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a organization_id=default_user.organization_id, embedding=[0.1], embedding_config=DEFAULT_EMBEDDING_CONFIG, - metadata_={"type": "test"}, + metadata={"type": "test"}, ), actor=default_user, ) @@ -348,7 +348,7 @@ def default_block(server: SyncServer, default_user): value="Default Block Content", description="A default test block", limit=1000, - metadata_={"type": "test"}, + metadata={"type": "test"}, ) block = block_manager.create_or_update_block(block_data, actor=default_user) yield block @@ -363,7 +363,7 @@ def other_block(server: SyncServer, default_user): value="Other Block Content", description="Another test block", limit=500, - metadata_={"type": "test"}, + metadata={"type": "test"}, ) block = block_manager.create_or_update_block(block_data, actor=default_user) yield block @@ -444,7 +444,7 @@ def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_too source_ids=[default_source.id], tags=["a", "b"], description="test_description", - metadata_={"test_key": "test_value"}, + metadata={"test_key": "test_value"}, tool_rules=[InitToolRule(tool_name=print_tool.name)], initial_message_sequence=[MessageCreate(role=MessageRole.user, text="hello world")], tool_exec_environment_variables={"test_env_var_key_a": "test_env_var_value_a", "test_env_var_key_b": "test_env_var_value_b"}, @@ -600,7 +600,7 @@ def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, othe llm_config=LLMConfig.default_config("gpt-4o-mini"), embedding_config=EmbeddingConfig.default_config(model_name="letta"), message_ids=["10", "20"], - metadata_={"train_key": "train_value"}, + metadata={"train_key": "train_value"}, tool_exec_environment_variables={"test_env_var_key_a": "a", "new_tool_exec_key": "n"}, ) @@ -1938,7 +1938,7 @@ def test_create_block(server: SyncServer, default_user): template_name="sample_template", description="A test block", limit=1000, - metadata_={"example": "data"}, + metadata={"example": "data"}, ) block = block_manager.create_or_update_block(block_create, actor=default_user) @@ -1950,7 +1950,7 @@ def test_create_block(server: SyncServer, default_user): assert block.template_name == block_create.template_name assert block.description == block_create.description assert block.limit == block_create.limit - assert block.metadata_ == block_create.metadata_ + assert block.metadata == block_create.metadata assert block.organization_id == default_user.organization_id @@ -2033,7 +2033,7 @@ def test_create_source(server: SyncServer, default_user): source_pydantic = PydanticSource( name="Test Source", description="This is a test source.", - metadata_={"type": "test"}, + metadata={"type": "test"}, embedding_config=DEFAULT_EMBEDDING_CONFIG, ) source = server.source_manager.create_source(source=source_pydantic, actor=default_user) @@ -2041,7 +2041,7 @@ def test_create_source(server: SyncServer, default_user): # Assertions to check the created source assert source.name == source_pydantic.name assert source.description == source_pydantic.description - assert source.metadata_ == source_pydantic.metadata_ + assert source.metadata == source_pydantic.metadata assert source.organization_id == default_user.organization_id @@ -2051,14 +2051,14 @@ def test_create_sources_with_same_name_does_not_error(server: SyncServer, defaul source_pydantic = PydanticSource( name=name, description="This is a test source.", - metadata_={"type": "medical"}, + metadata={"type": "medical"}, embedding_config=DEFAULT_EMBEDDING_CONFIG, ) source = server.source_manager.create_source(source=source_pydantic, actor=default_user) source_pydantic = PydanticSource( name=name, description="This is a different test source.", - metadata_={"type": "legal"}, + metadata={"type": "legal"}, embedding_config=DEFAULT_EMBEDDING_CONFIG, ) same_source = server.source_manager.create_source(source=source_pydantic, actor=default_user) @@ -2073,13 +2073,13 @@ def test_update_source(server: SyncServer, default_user): source = server.source_manager.create_source(source=source_pydantic, actor=default_user) # Update the source - update_data = SourceUpdate(name="Updated Source", description="Updated description", metadata_={"type": "updated"}) + update_data = SourceUpdate(name="Updated Source", description="Updated description", metadata={"type": "updated"}) updated_source = server.source_manager.update_source(source_id=source.id, source_update=update_data, actor=default_user) # Assertions to verify update assert updated_source.name == update_data.name assert updated_source.description == update_data.description - assert updated_source.metadata_ == update_data.metadata_ + assert updated_source.metadata == update_data.metadata def test_delete_source(server: SyncServer, default_user): @@ -2411,7 +2411,7 @@ def test_create_job(server: SyncServer, default_user): """Test creating a job.""" job_data = PydanticJob( status=JobStatus.created, - metadata_={"type": "test"}, + metadata={"type": "test"}, ) created_job = server.job_manager.create_job(job_data, actor=default_user) @@ -2419,7 +2419,7 @@ def test_create_job(server: SyncServer, default_user): # Assertions to ensure the created job matches the expected values assert created_job.user_id == default_user.id assert created_job.status == JobStatus.created - assert created_job.metadata_ == {"type": "test"} + assert created_job.metadata == {"type": "test"} def test_get_job_by_id(server: SyncServer, default_user): @@ -2427,7 +2427,7 @@ def test_get_job_by_id(server: SyncServer, default_user): # Create a job job_data = PydanticJob( status=JobStatus.created, - metadata_={"type": "test"}, + metadata={"type": "test"}, ) created_job = server.job_manager.create_job(job_data, actor=default_user) @@ -2437,7 +2437,7 @@ def test_get_job_by_id(server: SyncServer, default_user): # Assertions to ensure the fetched job matches the created job assert fetched_job.id == created_job.id assert fetched_job.status == JobStatus.created - assert fetched_job.metadata_ == {"type": "test"} + assert fetched_job.metadata == {"type": "test"} def test_list_jobs(server: SyncServer, default_user): @@ -2446,7 +2446,7 @@ def test_list_jobs(server: SyncServer, default_user): for i in range(3): job_data = PydanticJob( status=JobStatus.created, - metadata_={"type": f"test-{i}"}, + metadata={"type": f"test-{i}"}, ) server.job_manager.create_job(job_data, actor=default_user) @@ -2456,7 +2456,7 @@ def test_list_jobs(server: SyncServer, default_user): # Assertions to check that the created jobs are listed assert len(jobs) == 3 assert all(job.user_id == default_user.id for job in jobs) - assert all(job.metadata_["type"].startswith("test") for job in jobs) + assert all(job.metadata["type"].startswith("test") for job in jobs) def test_update_job_by_id(server: SyncServer, default_user): @@ -2464,17 +2464,18 @@ def test_update_job_by_id(server: SyncServer, default_user): # Create a job job_data = PydanticJob( status=JobStatus.created, - metadata_={"type": "test"}, + metadata={"type": "test"}, ) created_job = server.job_manager.create_job(job_data, actor=default_user) + assert created_job.metadata == {"type": "test"} # Update the job - update_data = JobUpdate(status=JobStatus.completed, metadata_={"type": "updated"}) + update_data = JobUpdate(status=JobStatus.completed, metadata={"type": "updated"}) updated_job = server.job_manager.update_job_by_id(created_job.id, update_data, actor=default_user) # Assertions to ensure the job was updated assert updated_job.status == JobStatus.completed - assert updated_job.metadata_ == {"type": "updated"} + assert updated_job.metadata == {"type": "updated"} assert updated_job.completed_at is not None @@ -2483,7 +2484,7 @@ def test_delete_job_by_id(server: SyncServer, default_user): # Create a job job_data = PydanticJob( status=JobStatus.created, - metadata_={"type": "test"}, + metadata={"type": "test"}, ) created_job = server.job_manager.create_job(job_data, actor=default_user) @@ -2500,7 +2501,7 @@ def test_update_job_auto_complete(server: SyncServer, default_user): # Create a job job_data = PydanticJob( status=JobStatus.created, - metadata_={"type": "test"}, + metadata={"type": "test"}, ) created_job = server.job_manager.create_job(job_data, actor=default_user) @@ -2533,7 +2534,7 @@ def test_list_jobs_pagination(server: SyncServer, default_user): for i in range(10): job_data = PydanticJob( status=JobStatus.created, - metadata_={"type": f"test-{i}"}, + metadata={"type": f"test-{i}"}, ) server.job_manager.create_job(job_data, actor=default_user) @@ -2550,15 +2551,15 @@ def test_list_jobs_by_status(server: SyncServer, default_user): # Create multiple jobs with different statuses job_data_created = PydanticJob( status=JobStatus.created, - metadata_={"type": "test-created"}, + metadata={"type": "test-created"}, ) job_data_in_progress = PydanticJob( status=JobStatus.running, - metadata_={"type": "test-running"}, + metadata={"type": "test-running"}, ) job_data_completed = PydanticJob( status=JobStatus.completed, - metadata_={"type": "test-completed"}, + metadata={"type": "test-completed"}, ) server.job_manager.create_job(job_data_created, actor=default_user) @@ -2572,13 +2573,13 @@ def test_list_jobs_by_status(server: SyncServer, default_user): # Assertions assert len(created_jobs) == 1 - assert created_jobs[0].metadata_["type"] == job_data_created.metadata_["type"] + assert created_jobs[0].metadata["type"] == job_data_created.metadata["type"] assert len(in_progress_jobs) == 1 - assert in_progress_jobs[0].metadata_["type"] == job_data_in_progress.metadata_["type"] + assert in_progress_jobs[0].metadata["type"] == job_data_in_progress.metadata["type"] assert len(completed_jobs) == 1 - assert completed_jobs[0].metadata_["type"] == job_data_completed.metadata_["type"] + assert completed_jobs[0].metadata["type"] == job_data_completed.metadata["type"] def test_list_jobs_filter_by_type(server: SyncServer, default_user, default_job): diff --git a/tests/test_server.py b/tests/test_server.py index 53d973e1..cbca00bb 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -930,6 +930,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot ), actor=actor, ) + assert source.created_by_id == user_id # Create a test file with some content test_file = tmp_path / "test.txt" @@ -943,7 +944,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot job = server.job_manager.create_job( PydanticJob( user_id=user_id, - metadata_={"type": "embedding", "filename": test_file.name, "source_id": source.id}, + metadata={"type": "embedding", "filename": test_file.name, "source_id": source.id}, ), actor=actor, ) @@ -959,8 +960,8 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot # Verify job completed successfully job = server.job_manager.get_job_by_id(job_id=job.id, actor=actor) assert job.status == "completed" - assert job.metadata_["num_passages"] == 1 - assert job.metadata_["num_documents"] == 1 + assert job.metadata["num_passages"] == 1 + assert job.metadata["num_documents"] == 1 # Verify passages were added first_file_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor) @@ -974,7 +975,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot job2 = server.job_manager.create_job( PydanticJob( user_id=user_id, - metadata_={"type": "embedding", "filename": test_file2.name, "source_id": source.id}, + metadata={"type": "embedding", "filename": test_file2.name, "source_id": source.id}, ), actor=actor, ) @@ -990,8 +991,8 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot # Verify second job completed successfully job2 = server.job_manager.get_job_by_id(job_id=job2.id, actor=actor) assert job2.status == "completed" - assert job2.metadata_["num_passages"] >= 10 - assert job2.metadata_["num_documents"] == 1 + assert job2.metadata["num_passages"] >= 10 + assert job2.metadata["num_documents"] == 1 # Verify passages were appended (not replaced) final_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)