chore: rename metadata_ field to metadata in pydantic (#732)

This commit is contained in:
cthomas
2025-01-22 19:05:41 -08:00
committed by GitHub
parent cc638e3593
commit 98c5702ef9
34 changed files with 133 additions and 111 deletions

View File

@@ -101,7 +101,7 @@
}
],
"source": [
"client.jobs.get(job_id=job.id).metadata_"
"client.jobs.get(job_id=job.id).metadata"
]
},
{

View File

@@ -1091,7 +1091,7 @@ def save_agent(agent: Agent):
embedding_config=agent_state.embedding_config,
message_ids=agent_state.message_ids,
description=agent_state.description,
metadata_=agent_state.metadata_,
metadata=agent_state.metadata,
# TODO: Add this back in later
# tool_exec_environment_variables=agent_state.get_agent_env_vars_as_dict(),
)

View File

@@ -87,7 +87,7 @@ class ChatOnlyAgent(Agent):
memory=offline_memory,
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
tool_ids=self.agent_state.metadata_.get("offline_memory_tools", []),
tool_ids=self.agent_state.metadata.get("offline_memory_tools", []),
include_base_tools=False,
)
self.offline_memory_agent.memory.update_block_value(label="conversation_block", value=recent_convo)

View File

@@ -586,7 +586,7 @@ class RESTClient(AbstractClient):
# create agent
create_params = {
"description": description,
"metadata_": metadata,
"metadata": metadata,
"memory_blocks": [],
"block_ids": [b.id for b in memory.get_blocks()] + block_ids,
"tool_ids": tool_ids,
@@ -685,7 +685,7 @@ class RESTClient(AbstractClient):
tool_ids=tool_ids,
tags=tags,
description=description,
metadata_=metadata,
metadata=metadata,
llm_config=llm_config,
embedding_config=embedding_config,
message_ids=message_ids,
@@ -2331,7 +2331,7 @@ class LocalClient(AbstractClient):
# Create the base parameters
create_params = {
"description": description,
"metadata_": metadata,
"metadata": metadata,
"memory_blocks": [],
"block_ids": [b.id for b in memory.get_blocks()] + block_ids,
"tool_ids": tool_ids,
@@ -2423,7 +2423,7 @@ class LocalClient(AbstractClient):
tool_ids=tool_ids,
tags=tags,
description=description,
metadata_=metadata,
metadata=metadata,
llm_config=llm_config,
embedding_config=embedding_config,
message_ids=message_ids,
@@ -3100,7 +3100,7 @@ class LocalClient(AbstractClient):
job = Job(
user_id=self.user_id,
status=JobStatus.created,
metadata_={"type": "embedding", "filename": filename, "source_id": source_id},
metadata={"type": "embedding", "filename": filename, "source_id": source_id},
)
job = self.server.job_manager.create_job(pydantic_job=job, actor=self.user)

View File

@@ -77,7 +77,7 @@ def load_data(connector: DataConnector, source: Source, passage_manager: Passage
text=passage_text,
file_id=file_metadata.id,
source_id=source.id,
metadata_=passage_metadata,
metadata=passage_metadata,
organization_id=source.organization_id,
embedding_config=source.embedding_config,
embedding=embedding,

View File

@@ -113,14 +113,14 @@ class Agent(SqlalchemyBase, OrganizationMixin):
"description": self.description,
"message_ids": self.message_ids,
"tools": self.tools,
"sources": self.sources,
"sources": [source.to_pydantic() for source in self.sources],
"tags": [t.tag for t in self.tags],
"tool_rules": self.tool_rules,
"system": self.system,
"agent_type": self.agent_type,
"llm_config": self.llm_config,
"embedding_config": self.embedding_config,
"metadata_": self.metadata_,
"metadata": self.metadata_,
"memory": Memory(blocks=[b.to_pydantic() for b in self.core_memory]),
"created_by_id": self.created_by_id,
"last_updated_by_id": self.last_updated_by_id,

View File

@@ -45,7 +45,9 @@ class Block(OrganizationMixin, SqlalchemyBase):
Schema = Persona
case _:
Schema = PydanticBlock
return Schema.model_validate(self)
model_dict = {k: v for k, v in self.__dict__.items() if k in self.__pydantic_model__.model_fields}
model_dict["metadata"] = self.metadata_
return Schema.model_validate(model_dict)
@event.listens_for(Block, "after_update") # Changed from 'before_update'

View File

@@ -449,6 +449,11 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
def to_pydantic(self) -> "BaseModel":
"""converts to the basic pydantic model counterpart"""
if hasattr(self, "metadata_"):
model_dict = {k: v for k, v in self.__dict__.items() if k in self.__pydantic_model__.model_fields}
model_dict["metadata"] = self.metadata_
return self.__pydantic_model__.model_validate(model_dict)
return self.__pydantic_model__.model_validate(self)
def to_record(self) -> "BaseModel":

View File

@@ -72,7 +72,7 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the agent.")
description: Optional[str] = Field(None, description="The description of the agent.")
metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
metadata: Optional[Dict] = Field(None, description="The metadata of the agent.")
memory: Memory = Field(..., description="The in-context memory of the agent.")
tools: List[Tool] = Field(..., description="The tools used by the agent.")
@@ -122,7 +122,7 @@ class CreateAgent(BaseModel, validate_assignment=True): #
False, description="If true, attaches the Letta multi-agent tools (e.g. sending a message to another agent)."
)
description: Optional[str] = Field(None, description="The description of the agent.")
metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
metadata: Optional[Dict] = Field(None, description="The metadata of the agent.")
model: Optional[str] = Field(
None,
description="The LLM configuration handle used by the agent, specified in the format "
@@ -203,7 +203,7 @@ class UpdateAgent(BaseModel):
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
description: Optional[str] = Field(None, description="The description of the agent.")
metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
metadata: Optional[Dict] = Field(None, description="The metadata of the agent.")
tool_exec_environment_variables: Optional[Dict[str, str]] = Field(
None, description="The environment variables for tool execution specific to this agent."
)

View File

@@ -27,7 +27,7 @@ class BaseBlock(LettaBase, validate_assignment=True):
# metadata
description: Optional[str] = Field(None, description="Description of the block.")
metadata_: Optional[dict] = Field({}, description="Metadata of the block.")
metadata: Optional[dict] = Field({}, description="Metadata of the block.")
# def __len__(self):
# return len(self.value)
@@ -63,7 +63,7 @@ class Block(BaseBlock):
label (str): The label of the block (e.g. 'human', 'persona'). This defines a category for the block.
template_name (str): The name of the block template (if it is a template).
description (str): Description of the block.
metadata_ (Dict): Metadata of the block.
metadata (Dict): Metadata of the block.
user_id (str): The unique identifier of the user associated with the block.
"""

View File

@@ -12,7 +12,7 @@ class JobBase(OrmMetadataBase):
__id_prefix__ = "job"
status: JobStatus = Field(default=JobStatus.created, description="The status of the job.")
completed_at: Optional[datetime] = Field(None, description="The unix timestamp of when the job was completed.")
metadata_: Optional[dict] = Field(None, description="The metadata of the job.")
metadata: Optional[dict] = Field(None, description="The metadata of the job.")
job_type: JobType = Field(default=JobType.JOB, description="The type of the job.")

View File

@@ -88,6 +88,12 @@ class LettaBase(BaseModel):
return f"{cls.__id_prefix__}-{v}"
return v
def model_dump(self, to_orm: bool = False, **kwargs):
data = super().model_dump(**kwargs)
if to_orm and "metadata" in data:
data["metadata_"] = data.pop("metadata")
return data
class OrmMetadataBase(LettaBase):
# metadata fields

View File

@@ -23,7 +23,7 @@ class PassageBase(OrmMetadataBase):
# file association
file_id: Optional[str] = Field(None, description="The unique identifier of the file associated with the passage.")
metadata_: Optional[Dict] = Field({}, description="The metadata of the passage.")
metadata: Optional[Dict] = Field({}, description="The metadata of the passage.")
class Passage(PassageBase):

View File

@@ -24,7 +24,7 @@ class Source(BaseSource):
name (str): The name of the source.
embedding_config (EmbeddingConfig): The embedding configuration used by the source.
user_id (str): The ID of the user that created the source.
metadata_ (dict): Metadata associated with the source.
metadata (dict): Metadata associated with the source.
description (str): The description of the source.
"""
@@ -33,7 +33,7 @@ class Source(BaseSource):
description: Optional[str] = Field(None, description="The description of the source.")
embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the source.")
organization_id: Optional[str] = Field(None, description="The ID of the organization that created the source.")
metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.")
metadata: Optional[dict] = Field(None, description="Metadata associated with the source.")
# metadata fields
created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
@@ -54,7 +54,7 @@ class SourceCreate(BaseSource):
# optional
description: Optional[str] = Field(None, description="The description of the source.")
metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.")
metadata: Optional[dict] = Field(None, description="Metadata associated with the source.")
class SourceUpdate(BaseSource):
@@ -64,5 +64,5 @@ class SourceUpdate(BaseSource):
name: Optional[str] = Field(None, description="The name of the source.")
description: Optional[str] = Field(None, description="The description of the source.")
metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.")
metadata: Optional[dict] = Field(None, description="Metadata associated with the source.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the source.")

View File

@@ -556,7 +556,7 @@ async def process_message_background(
job_update = JobUpdate(
status=JobStatus.completed,
completed_at=datetime.utcnow(),
metadata_={"result": result.model_dump()}, # Store the result in metadata
metadata={"result": result.model_dump()}, # Store the result in metadata
)
server.job_manager.update_job_by_id(job_id=job_id, job_update=job_update, actor=actor)
@@ -568,7 +568,7 @@ async def process_message_background(
job_update = JobUpdate(
status=JobStatus.failed,
completed_at=datetime.utcnow(),
metadata_={"error": str(e)},
metadata={"error": str(e)},
)
server.job_manager.update_job_by_id(job_id=job_id, job_update=job_update, actor=actor)
raise
@@ -596,7 +596,7 @@ async def send_message_async(
run = Run(
user_id=actor.id,
status=JobStatus.created,
metadata_={
metadata={
"job_type": "send_message_async",
"agent_id": agent_id,
},

View File

@@ -26,9 +26,9 @@ def list_jobs(
jobs = server.job_manager.list_jobs(actor=actor)
if source_id:
# can't be in the ORM since we have source_id stored in the metadata_
# can't be in the ORM since we have source_id stored in the metadata
# TODO: Probably change this
jobs = [job for job in jobs if job.metadata_.get("source_id") == source_id]
jobs = [job for job in jobs if job.metadata.get("source_id") == source_id]
return jobs

View File

@@ -131,7 +131,7 @@ def upload_file_to_source(
# create job
job = Job(
user_id=actor.id,
metadata_={"type": "embedding", "filename": file.filename, "source_id": source_id},
metadata={"type": "embedding", "filename": file.filename, "source_id": source_id},
completed_at=None,
)
job_id = job.id

View File

@@ -956,8 +956,8 @@ class SyncServer(Server):
# update job status
job.status = JobStatus.completed
job.metadata_["num_passages"] = num_passages
job.metadata_["num_documents"] = num_documents
job.metadata["num_passages"] = num_passages
job.metadata["num_documents"] = num_documents
self.job_manager.update_job_by_id(job_id=job_id, job_update=JobUpdate(**job.model_dump()), actor=actor)
# update all agents who have this source attached
@@ -1019,7 +1019,7 @@ class SyncServer(Server):
attached_agents = [{"id": agent.id, "name": agent.name} for agent in agents]
# Overwrite metadata field, should be empty anyways
source.metadata_ = dict(
source.metadata = dict(
num_documents=num_documents,
num_passages=num_passages,
attached_agents=attached_agents,

View File

@@ -82,7 +82,7 @@ class AgentManager:
block_ids = list(agent_create.block_ids or []) # Create a local copy to avoid modifying the original
if agent_create.memory_blocks:
for create_block in agent_create.memory_blocks:
block = self.block_manager.create_or_update_block(PydanticBlock(**create_block.model_dump()), actor=actor)
block = self.block_manager.create_or_update_block(PydanticBlock(**create_block.model_dump(to_orm=True)), actor=actor)
block_ids.append(block.id)
# TODO: Remove this block once we deprecate the legacy `tools` field
@@ -117,7 +117,7 @@ class AgentManager:
source_ids=agent_create.source_ids or [],
tags=agent_create.tags or [],
description=agent_create.description,
metadata_=agent_create.metadata_,
metadata=agent_create.metadata,
tool_rules=agent_create.tool_rules,
actor=actor,
)
@@ -177,7 +177,7 @@ class AgentManager:
source_ids: List[str],
tags: List[str],
description: Optional[str] = None,
metadata_: Optional[Dict] = None,
metadata: Optional[Dict] = None,
tool_rules: Optional[List[PydanticToolRule]] = None,
) -> PydanticAgentState:
"""Create a new agent."""
@@ -191,7 +191,7 @@ class AgentManager:
"embedding_config": embedding_config,
"organization_id": actor.organization_id,
"description": description,
"metadata_": metadata_,
"metadata_": metadata,
"tool_rules": tool_rules,
}
@@ -242,11 +242,14 @@ class AgentManager:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
# Update scalar fields directly
scalar_fields = {"name", "system", "llm_config", "embedding_config", "message_ids", "tool_rules", "description", "metadata_"}
scalar_fields = {"name", "system", "llm_config", "embedding_config", "message_ids", "tool_rules", "description", "metadata"}
for field in scalar_fields:
value = getattr(agent_update, field, None)
if value is not None:
setattr(agent, field, value)
if field == "metadata":
setattr(agent, "metadata_", value)
else:
setattr(agent, field, value)
# Update relationships using _process_relationship and _process_tags
if agent_update.tool_ids is not None:

View File

@@ -24,11 +24,11 @@ class BlockManager:
"""Create a new block based on the Block schema."""
db_block = self.get_block_by_id(block.id, actor)
if db_block:
update_data = BlockUpdate(**block.model_dump(exclude_none=True))
update_data = BlockUpdate(**block.model_dump(to_orm=True, exclude_none=True))
self.update_block(block.id, update_data, actor)
else:
with self.session_maker() as session:
data = block.model_dump(exclude_none=True)
data = block.model_dump(to_orm=True, exclude_none=True)
block = BlockModel(**data, organization_id=actor.organization_id)
block.create(session, actor=actor)
return block.to_pydantic()
@@ -40,7 +40,7 @@ class BlockManager:
with self.session_maker() as session:
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
update_data = block_update.model_dump(exclude_unset=True, exclude_none=True)
update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
for key, value in update_data.items():
setattr(block, key, value)

View File

@@ -39,7 +39,7 @@ class JobManager:
with self.session_maker() as session:
# Associate the job with the user
pydantic_job.user_id = actor.id
job_data = pydantic_job.model_dump()
job_data = pydantic_job.model_dump(to_orm=True)
job = JobModel(**job_data)
job.create(session, actor=actor) # Save job in the database
return job.to_pydantic()
@@ -52,7 +52,7 @@ class JobManager:
job = self._verify_job_access(session=session, job_id=job_id, actor=actor, access=["write"])
# Update job attributes with only the fields that were explicitly set
update_data = job_update.model_dump(exclude_unset=True, exclude_none=True)
update_data = job_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
# Automatically update the completion timestamp if status is set to 'completed'
if update_data.get("status") == JobStatus.completed and not job.completed_at:
@@ -62,7 +62,9 @@ class JobManager:
setattr(job, key, value)
# Save the updated job to the database
return job.update(db_session=session) # TODO: Add this later , actor=actor)
job.update(db_session=session) # TODO: Add this later , actor=actor)
return job.to_pydantic()
@enforce_types
def get_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob:

View File

@@ -44,7 +44,7 @@ class OrganizationManager:
@enforce_types
def _create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
with self.session_maker() as session:
org = OrganizationModel(**pydantic_org.model_dump())
org = OrganizationModel(**pydantic_org.model_dump(to_orm=True))
org.create(session)
return org.to_pydantic()

View File

@@ -38,14 +38,14 @@ class PassageManager:
def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
"""Create a new passage in the appropriate table based on whether it has agent_id or source_id."""
# Common fields for both passage types
data = pydantic_passage.model_dump()
data = pydantic_passage.model_dump(to_orm=True)
common_fields = {
"id": data.get("id"),
"text": data["text"],
"embedding": data["embedding"],
"embedding_config": data["embedding_config"],
"organization_id": data["organization_id"],
"metadata_": data.get("metadata_", {}),
"metadata_": data.get("metadata", {}),
"is_deleted": data.get("is_deleted", False),
"created_at": data.get("created_at", datetime.utcnow()),
}
@@ -145,7 +145,7 @@ class PassageManager:
raise ValueError(f"Passage with id {passage_id} does not exist.")
# Update the database record with values from the provided record
update_data = passage.model_dump(exclude_unset=True, exclude_none=True)
update_data = passage.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
for key, value in update_data.items():
setattr(curr_passage, key, value)

View File

@@ -24,7 +24,7 @@ class ProviderManager:
# Lazily create the provider id prior to persistence
provider.resolve_identifier()
new_provider = ProviderModel(**provider.model_dump(exclude_unset=True))
new_provider = ProviderModel(**provider.model_dump(to_orm=True, exclude_unset=True))
new_provider.create(session)
return new_provider.to_pydantic()
@@ -36,7 +36,7 @@ class ProviderManager:
existing_provider = ProviderModel.read(db_session=session, identifier=provider_update.id)
# Update only the fields that are provided in ProviderUpdate
update_data = provider_update.model_dump(exclude_unset=True, exclude_none=True)
update_data = provider_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
for key, value in update_data.items():
setattr(existing_provider, key, value)

View File

@@ -172,7 +172,7 @@ class SandboxConfigManager:
return db_env_var
else:
with self.session_maker() as session:
env_var = SandboxEnvVarModel(**env_var.model_dump(exclude_none=True))
env_var = SandboxEnvVarModel(**env_var.model_dump(to_orm=True, exclude_none=True))
env_var.create(session, actor=actor)
return env_var.to_pydantic()
@@ -183,7 +183,7 @@ class SandboxConfigManager:
"""Update an existing sandbox environment variable."""
with self.session_maker() as session:
env_var = SandboxEnvVarModel.read(db_session=session, identifier=env_var_id, actor=actor)
update_data = env_var_update.model_dump(exclude_unset=True, exclude_none=True)
update_data = env_var_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
update_data = {key: value for key, value in update_data.items() if getattr(env_var, key) != value}
if update_data:

View File

@@ -30,7 +30,7 @@ class SourceManager:
with self.session_maker() as session:
# Provide default embedding config if not given
source.organization_id = actor.organization_id
source = SourceModel(**source.model_dump(exclude_none=True))
source = SourceModel(**source.model_dump(to_orm=True, exclude_none=True))
source.create(session, actor=actor)
return source.to_pydantic()
@@ -41,7 +41,7 @@ class SourceManager:
source = SourceModel.read(db_session=session, identifier=source_id, actor=actor)
# get update dictionary
update_data = source_update.model_dump(exclude_unset=True, exclude_none=True)
update_data = source_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
# Remove redundant update fields
update_data = {key: value for key, value in update_data.items() if getattr(source, key) != value}
@@ -132,7 +132,7 @@ class SourceManager:
else:
with self.session_maker() as session:
file_metadata.organization_id = actor.organization_id
file_metadata = FileMetadataModel(**file_metadata.model_dump(exclude_none=True))
file_metadata = FileMetadataModel(**file_metadata.model_dump(to_orm=True, exclude_none=True))
file_metadata.create(session, actor=actor)
return file_metadata.to_pydantic()

View File

@@ -364,7 +364,9 @@ class ToolExecutionSandbox:
sbx = Sandbox(sandbox_config.get_e2b_config().template, metadata={self.METADATA_CONFIG_STATE_KEY: state_hash})
else:
# no template
sbx = Sandbox(metadata={self.METADATA_CONFIG_STATE_KEY: state_hash}, **e2b_config.model_dump(exclude={"pip_requirements"}))
sbx = Sandbox(
metadata={self.METADATA_CONFIG_STATE_KEY: state_hash}, **e2b_config.model_dump(to_orm=True, exclude={"pip_requirements"})
)
# install pip requirements
if e2b_config.pip_requirements:

View File

@@ -39,7 +39,7 @@ class ToolManager:
tool = self.get_tool_by_name(tool_name=pydantic_tool.name, actor=actor)
if tool:
# Put to dict and remove fields that should not be reset
update_data = pydantic_tool.model_dump(exclude_unset=True, exclude_none=True)
update_data = pydantic_tool.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
# If there's anything to update
if update_data:
@@ -67,7 +67,7 @@ class ToolManager:
# Auto-generate description if not provided
if pydantic_tool.description is None:
pydantic_tool.description = pydantic_tool.json_schema.get("description", None)
tool_data = pydantic_tool.model_dump()
tool_data = pydantic_tool.model_dump(to_orm=True)
tool = ToolModel(**tool_data)
tool.create(session, actor=actor) # Re-raise other database-related errors
@@ -112,7 +112,7 @@ class ToolManager:
tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor)
# Update tool attributes with only the fields that were explicitly set
update_data = tool_update.model_dump(exclude_none=True)
update_data = tool_update.model_dump(to_orm=True, exclude_none=True)
for key, value in update_data.items():
setattr(tool, key, value)

View File

@@ -45,7 +45,7 @@ class UserManager:
def create_user(self, pydantic_user: PydanticUser) -> PydanticUser:
"""Create a new user if it doesn't already exist."""
with self.session_maker() as session:
new_user = UserModel(**pydantic_user.model_dump())
new_user = UserModel(**pydantic_user.model_dump(to_orm=True))
new_user.create(session)
return new_user.to_pydantic()
@@ -57,7 +57,7 @@ class UserManager:
existing_user = UserModel.read(db_session=session, identifier=user_update.id)
# Update only the fields that are provided in UserUpdate
update_data = user_update.model_dump(exclude_unset=True, exclude_none=True)
update_data = user_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
for key, value in update_data.items():
setattr(existing_user, key, value)

View File

@@ -10,14 +10,14 @@ from letta.schemas.source import Source
def upload_file_using_client(client: Union[LocalClient, RESTClient], source: Source, filename: str) -> Job:
# load a file into a source (non-blocking job)
upload_job = client.load_file_to_source(filename=filename, source_id=source.id, blocking=False)
print("Upload job", upload_job, upload_job.status, upload_job.metadata_)
print("Upload job", upload_job, upload_job.status, upload_job.metadata)
# view active jobs
active_jobs = client.list_active_jobs()
jobs = client.list_jobs()
assert upload_job.id in [j.id for j in jobs]
assert len(active_jobs) == 1
assert active_jobs[0].metadata_["source_id"] == source.id
assert active_jobs[0].metadata["source_id"] == source.id
# wait for job to finish (with timeout)
timeout = 240

View File

@@ -96,7 +96,7 @@ def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, Up
# Assert scalar fields
assert agent.system == request.system, f"System prompt mismatch: {agent.system} != {request.system}"
assert agent.description == request.description, f"Description mismatch: {agent.description} != {request.description}"
assert agent.metadata_ == request.metadata_, f"Metadata mismatch: {agent.metadata_} != {request.metadata_}"
assert agent.metadata == request.metadata, f"Metadata mismatch: {agent.metadata} != {request.metadata}"
# Assert agent env vars
if hasattr(request, "tool_exec_environment_variables"):

View File

@@ -466,8 +466,8 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
assert len(sources) == 1
# TODO: add back?
assert sources[0].metadata_["num_passages"] == 0
assert sources[0].metadata_["num_documents"] == 0
assert sources[0].metadata["num_passages"] == 0
assert sources[0].metadata["num_documents"] == 0
# update the source
original_id = source.id
@@ -491,7 +491,7 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
filename = "tests/data/memgpt_paper.pdf"
upload_job = upload_file_using_client(client, source, filename)
job = client.get_job(upload_job.id)
created_passages = job.metadata_["num_passages"]
created_passages = job.metadata["num_passages"]
# TODO: add test for blocking job
@@ -515,8 +515,8 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
# check number of passages
sources = client.list_sources()
# TODO: add back?
# assert sources.sources[0].metadata_["num_passages"] > 0
# assert sources.sources[0].metadata_["num_documents"] == 0 # TODO: fix this once document store added
# assert sources.sources[0].metadata["num_passages"] > 0
# assert sources.sources[0].metadata["num_documents"] == 0 # TODO: fix this once document store added
print(sources)
# detach the source

View File

@@ -134,7 +134,7 @@ def default_source(server: SyncServer, default_user):
source_pydantic = PydanticSource(
name="Test Source",
description="This is a test source.",
metadata_={"type": "test"},
metadata={"type": "test"},
embedding_config=DEFAULT_EMBEDDING_CONFIG,
)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
@@ -146,7 +146,7 @@ def other_source(server: SyncServer, default_user):
source_pydantic = PydanticSource(
name="Another Test Source",
description="This is yet another test source.",
metadata_={"type": "another_test"},
metadata={"type": "another_test"},
embedding_config=DEFAULT_EMBEDDING_CONFIG,
)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
@@ -235,7 +235,7 @@ def agent_passage_fixture(server: SyncServer, default_user, sarah_agent):
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"},
metadata={"type": "test"},
),
actor=default_user,
)
@@ -253,7 +253,7 @@ def source_passage_fixture(server: SyncServer, default_user, default_file, defau
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"},
metadata={"type": "test"},
),
actor=default_user,
)
@@ -273,7 +273,7 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"},
metadata={"type": "test"},
),
actor=default_user,
)
@@ -291,7 +291,7 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"},
metadata={"type": "test"},
),
actor=default_user,
)
@@ -348,7 +348,7 @@ def default_block(server: SyncServer, default_user):
value="Default Block Content",
description="A default test block",
limit=1000,
metadata_={"type": "test"},
metadata={"type": "test"},
)
block = block_manager.create_or_update_block(block_data, actor=default_user)
yield block
@@ -363,7 +363,7 @@ def other_block(server: SyncServer, default_user):
value="Other Block Content",
description="Another test block",
limit=500,
metadata_={"type": "test"},
metadata={"type": "test"},
)
block = block_manager.create_or_update_block(block_data, actor=default_user)
yield block
@@ -444,7 +444,7 @@ def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_too
source_ids=[default_source.id],
tags=["a", "b"],
description="test_description",
metadata_={"test_key": "test_value"},
metadata={"test_key": "test_value"},
tool_rules=[InitToolRule(tool_name=print_tool.name)],
initial_message_sequence=[MessageCreate(role=MessageRole.user, text="hello world")],
tool_exec_environment_variables={"test_env_var_key_a": "test_env_var_value_a", "test_env_var_key_b": "test_env_var_value_b"},
@@ -600,7 +600,7 @@ def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, othe
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(model_name="letta"),
message_ids=["10", "20"],
metadata_={"train_key": "train_value"},
metadata={"train_key": "train_value"},
tool_exec_environment_variables={"test_env_var_key_a": "a", "new_tool_exec_key": "n"},
)
@@ -1938,7 +1938,7 @@ def test_create_block(server: SyncServer, default_user):
template_name="sample_template",
description="A test block",
limit=1000,
metadata_={"example": "data"},
metadata={"example": "data"},
)
block = block_manager.create_or_update_block(block_create, actor=default_user)
@@ -1950,7 +1950,7 @@ def test_create_block(server: SyncServer, default_user):
assert block.template_name == block_create.template_name
assert block.description == block_create.description
assert block.limit == block_create.limit
assert block.metadata_ == block_create.metadata_
assert block.metadata == block_create.metadata
assert block.organization_id == default_user.organization_id
@@ -2033,7 +2033,7 @@ def test_create_source(server: SyncServer, default_user):
source_pydantic = PydanticSource(
name="Test Source",
description="This is a test source.",
metadata_={"type": "test"},
metadata={"type": "test"},
embedding_config=DEFAULT_EMBEDDING_CONFIG,
)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
@@ -2041,7 +2041,7 @@ def test_create_source(server: SyncServer, default_user):
# Assertions to check the created source
assert source.name == source_pydantic.name
assert source.description == source_pydantic.description
assert source.metadata_ == source_pydantic.metadata_
assert source.metadata == source_pydantic.metadata
assert source.organization_id == default_user.organization_id
@@ -2051,14 +2051,14 @@ def test_create_sources_with_same_name_does_not_error(server: SyncServer, defaul
source_pydantic = PydanticSource(
name=name,
description="This is a test source.",
metadata_={"type": "medical"},
metadata={"type": "medical"},
embedding_config=DEFAULT_EMBEDDING_CONFIG,
)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
source_pydantic = PydanticSource(
name=name,
description="This is a different test source.",
metadata_={"type": "legal"},
metadata={"type": "legal"},
embedding_config=DEFAULT_EMBEDDING_CONFIG,
)
same_source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
@@ -2073,13 +2073,13 @@ def test_update_source(server: SyncServer, default_user):
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
# Update the source
update_data = SourceUpdate(name="Updated Source", description="Updated description", metadata_={"type": "updated"})
update_data = SourceUpdate(name="Updated Source", description="Updated description", metadata={"type": "updated"})
updated_source = server.source_manager.update_source(source_id=source.id, source_update=update_data, actor=default_user)
# Assertions to verify update
assert updated_source.name == update_data.name
assert updated_source.description == update_data.description
assert updated_source.metadata_ == update_data.metadata_
assert updated_source.metadata == update_data.metadata
def test_delete_source(server: SyncServer, default_user):
@@ -2411,7 +2411,7 @@ def test_create_job(server: SyncServer, default_user):
"""Test creating a job."""
job_data = PydanticJob(
status=JobStatus.created,
metadata_={"type": "test"},
metadata={"type": "test"},
)
created_job = server.job_manager.create_job(job_data, actor=default_user)
@@ -2419,7 +2419,7 @@ def test_create_job(server: SyncServer, default_user):
# Assertions to ensure the created job matches the expected values
assert created_job.user_id == default_user.id
assert created_job.status == JobStatus.created
assert created_job.metadata_ == {"type": "test"}
assert created_job.metadata == {"type": "test"}
def test_get_job_by_id(server: SyncServer, default_user):
@@ -2427,7 +2427,7 @@ def test_get_job_by_id(server: SyncServer, default_user):
# Create a job
job_data = PydanticJob(
status=JobStatus.created,
metadata_={"type": "test"},
metadata={"type": "test"},
)
created_job = server.job_manager.create_job(job_data, actor=default_user)
@@ -2437,7 +2437,7 @@ def test_get_job_by_id(server: SyncServer, default_user):
# Assertions to ensure the fetched job matches the created job
assert fetched_job.id == created_job.id
assert fetched_job.status == JobStatus.created
assert fetched_job.metadata_ == {"type": "test"}
assert fetched_job.metadata == {"type": "test"}
def test_list_jobs(server: SyncServer, default_user):
@@ -2446,7 +2446,7 @@ def test_list_jobs(server: SyncServer, default_user):
for i in range(3):
job_data = PydanticJob(
status=JobStatus.created,
metadata_={"type": f"test-{i}"},
metadata={"type": f"test-{i}"},
)
server.job_manager.create_job(job_data, actor=default_user)
@@ -2456,7 +2456,7 @@ def test_list_jobs(server: SyncServer, default_user):
# Assertions to check that the created jobs are listed
assert len(jobs) == 3
assert all(job.user_id == default_user.id for job in jobs)
assert all(job.metadata_["type"].startswith("test") for job in jobs)
assert all(job.metadata["type"].startswith("test") for job in jobs)
def test_update_job_by_id(server: SyncServer, default_user):
@@ -2464,17 +2464,18 @@ def test_update_job_by_id(server: SyncServer, default_user):
# Create a job
job_data = PydanticJob(
status=JobStatus.created,
metadata_={"type": "test"},
metadata={"type": "test"},
)
created_job = server.job_manager.create_job(job_data, actor=default_user)
assert created_job.metadata == {"type": "test"}
# Update the job
update_data = JobUpdate(status=JobStatus.completed, metadata_={"type": "updated"})
update_data = JobUpdate(status=JobStatus.completed, metadata={"type": "updated"})
updated_job = server.job_manager.update_job_by_id(created_job.id, update_data, actor=default_user)
# Assertions to ensure the job was updated
assert updated_job.status == JobStatus.completed
assert updated_job.metadata_ == {"type": "updated"}
assert updated_job.metadata == {"type": "updated"}
assert updated_job.completed_at is not None
@@ -2483,7 +2484,7 @@ def test_delete_job_by_id(server: SyncServer, default_user):
# Create a job
job_data = PydanticJob(
status=JobStatus.created,
metadata_={"type": "test"},
metadata={"type": "test"},
)
created_job = server.job_manager.create_job(job_data, actor=default_user)
@@ -2500,7 +2501,7 @@ def test_update_job_auto_complete(server: SyncServer, default_user):
# Create a job
job_data = PydanticJob(
status=JobStatus.created,
metadata_={"type": "test"},
metadata={"type": "test"},
)
created_job = server.job_manager.create_job(job_data, actor=default_user)
@@ -2533,7 +2534,7 @@ def test_list_jobs_pagination(server: SyncServer, default_user):
for i in range(10):
job_data = PydanticJob(
status=JobStatus.created,
metadata_={"type": f"test-{i}"},
metadata={"type": f"test-{i}"},
)
server.job_manager.create_job(job_data, actor=default_user)
@@ -2550,15 +2551,15 @@ def test_list_jobs_by_status(server: SyncServer, default_user):
# Create multiple jobs with different statuses
job_data_created = PydanticJob(
status=JobStatus.created,
metadata_={"type": "test-created"},
metadata={"type": "test-created"},
)
job_data_in_progress = PydanticJob(
status=JobStatus.running,
metadata_={"type": "test-running"},
metadata={"type": "test-running"},
)
job_data_completed = PydanticJob(
status=JobStatus.completed,
metadata_={"type": "test-completed"},
metadata={"type": "test-completed"},
)
server.job_manager.create_job(job_data_created, actor=default_user)
@@ -2572,13 +2573,13 @@ def test_list_jobs_by_status(server: SyncServer, default_user):
# Assertions
assert len(created_jobs) == 1
assert created_jobs[0].metadata_["type"] == job_data_created.metadata_["type"]
assert created_jobs[0].metadata["type"] == job_data_created.metadata["type"]
assert len(in_progress_jobs) == 1
assert in_progress_jobs[0].metadata_["type"] == job_data_in_progress.metadata_["type"]
assert in_progress_jobs[0].metadata["type"] == job_data_in_progress.metadata["type"]
assert len(completed_jobs) == 1
assert completed_jobs[0].metadata_["type"] == job_data_completed.metadata_["type"]
assert completed_jobs[0].metadata["type"] == job_data_completed.metadata["type"]
def test_list_jobs_filter_by_type(server: SyncServer, default_user, default_job):

View File

@@ -930,6 +930,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
),
actor=actor,
)
assert source.created_by_id == user_id
# Create a test file with some content
test_file = tmp_path / "test.txt"
@@ -943,7 +944,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
job = server.job_manager.create_job(
PydanticJob(
user_id=user_id,
metadata_={"type": "embedding", "filename": test_file.name, "source_id": source.id},
metadata={"type": "embedding", "filename": test_file.name, "source_id": source.id},
),
actor=actor,
)
@@ -959,8 +960,8 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
# Verify job completed successfully
job = server.job_manager.get_job_by_id(job_id=job.id, actor=actor)
assert job.status == "completed"
assert job.metadata_["num_passages"] == 1
assert job.metadata_["num_documents"] == 1
assert job.metadata["num_passages"] == 1
assert job.metadata["num_documents"] == 1
# Verify passages were added
first_file_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
@@ -974,7 +975,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
job2 = server.job_manager.create_job(
PydanticJob(
user_id=user_id,
metadata_={"type": "embedding", "filename": test_file2.name, "source_id": source.id},
metadata={"type": "embedding", "filename": test_file2.name, "source_id": source.id},
),
actor=actor,
)
@@ -990,8 +991,8 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
# Verify second job completed successfully
job2 = server.job_manager.get_job_by_id(job_id=job2.id, actor=actor)
assert job2.status == "completed"
assert job2.metadata_["num_passages"] >= 10
assert job2.metadata_["num_documents"] == 1
assert job2.metadata["num_passages"] >= 10
assert job2.metadata["num_documents"] == 1
# Verify passages were appended (not replaced)
final_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)