chore: rename metadata_ field to metadata in pydantic (#732)
This commit is contained in:
@@ -101,7 +101,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"client.jobs.get(job_id=job.id).metadata_"
|
||||
"client.jobs.get(job_id=job.id).metadata"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1091,7 +1091,7 @@ def save_agent(agent: Agent):
|
||||
embedding_config=agent_state.embedding_config,
|
||||
message_ids=agent_state.message_ids,
|
||||
description=agent_state.description,
|
||||
metadata_=agent_state.metadata_,
|
||||
metadata=agent_state.metadata,
|
||||
# TODO: Add this back in later
|
||||
# tool_exec_environment_variables=agent_state.get_agent_env_vars_as_dict(),
|
||||
)
|
||||
|
||||
@@ -87,7 +87,7 @@ class ChatOnlyAgent(Agent):
|
||||
memory=offline_memory,
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
|
||||
tool_ids=self.agent_state.metadata_.get("offline_memory_tools", []),
|
||||
tool_ids=self.agent_state.metadata.get("offline_memory_tools", []),
|
||||
include_base_tools=False,
|
||||
)
|
||||
self.offline_memory_agent.memory.update_block_value(label="conversation_block", value=recent_convo)
|
||||
|
||||
@@ -586,7 +586,7 @@ class RESTClient(AbstractClient):
|
||||
# create agent
|
||||
create_params = {
|
||||
"description": description,
|
||||
"metadata_": metadata,
|
||||
"metadata": metadata,
|
||||
"memory_blocks": [],
|
||||
"block_ids": [b.id for b in memory.get_blocks()] + block_ids,
|
||||
"tool_ids": tool_ids,
|
||||
@@ -685,7 +685,7 @@ class RESTClient(AbstractClient):
|
||||
tool_ids=tool_ids,
|
||||
tags=tags,
|
||||
description=description,
|
||||
metadata_=metadata,
|
||||
metadata=metadata,
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
message_ids=message_ids,
|
||||
@@ -2331,7 +2331,7 @@ class LocalClient(AbstractClient):
|
||||
# Create the base parameters
|
||||
create_params = {
|
||||
"description": description,
|
||||
"metadata_": metadata,
|
||||
"metadata": metadata,
|
||||
"memory_blocks": [],
|
||||
"block_ids": [b.id for b in memory.get_blocks()] + block_ids,
|
||||
"tool_ids": tool_ids,
|
||||
@@ -2423,7 +2423,7 @@ class LocalClient(AbstractClient):
|
||||
tool_ids=tool_ids,
|
||||
tags=tags,
|
||||
description=description,
|
||||
metadata_=metadata,
|
||||
metadata=metadata,
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
message_ids=message_ids,
|
||||
@@ -3100,7 +3100,7 @@ class LocalClient(AbstractClient):
|
||||
job = Job(
|
||||
user_id=self.user_id,
|
||||
status=JobStatus.created,
|
||||
metadata_={"type": "embedding", "filename": filename, "source_id": source_id},
|
||||
metadata={"type": "embedding", "filename": filename, "source_id": source_id},
|
||||
)
|
||||
job = self.server.job_manager.create_job(pydantic_job=job, actor=self.user)
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ def load_data(connector: DataConnector, source: Source, passage_manager: Passage
|
||||
text=passage_text,
|
||||
file_id=file_metadata.id,
|
||||
source_id=source.id,
|
||||
metadata_=passage_metadata,
|
||||
metadata=passage_metadata,
|
||||
organization_id=source.organization_id,
|
||||
embedding_config=source.embedding_config,
|
||||
embedding=embedding,
|
||||
|
||||
@@ -113,14 +113,14 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
"description": self.description,
|
||||
"message_ids": self.message_ids,
|
||||
"tools": self.tools,
|
||||
"sources": self.sources,
|
||||
"sources": [source.to_pydantic() for source in self.sources],
|
||||
"tags": [t.tag for t in self.tags],
|
||||
"tool_rules": self.tool_rules,
|
||||
"system": self.system,
|
||||
"agent_type": self.agent_type,
|
||||
"llm_config": self.llm_config,
|
||||
"embedding_config": self.embedding_config,
|
||||
"metadata_": self.metadata_,
|
||||
"metadata": self.metadata_,
|
||||
"memory": Memory(blocks=[b.to_pydantic() for b in self.core_memory]),
|
||||
"created_by_id": self.created_by_id,
|
||||
"last_updated_by_id": self.last_updated_by_id,
|
||||
|
||||
@@ -45,7 +45,9 @@ class Block(OrganizationMixin, SqlalchemyBase):
|
||||
Schema = Persona
|
||||
case _:
|
||||
Schema = PydanticBlock
|
||||
return Schema.model_validate(self)
|
||||
model_dict = {k: v for k, v in self.__dict__.items() if k in self.__pydantic_model__.model_fields}
|
||||
model_dict["metadata"] = self.metadata_
|
||||
return Schema.model_validate(model_dict)
|
||||
|
||||
|
||||
@event.listens_for(Block, "after_update") # Changed from 'before_update'
|
||||
|
||||
@@ -449,6 +449,11 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
|
||||
def to_pydantic(self) -> "BaseModel":
|
||||
"""converts to the basic pydantic model counterpart"""
|
||||
if hasattr(self, "metadata_"):
|
||||
model_dict = {k: v for k, v in self.__dict__.items() if k in self.__pydantic_model__.model_fields}
|
||||
model_dict["metadata"] = self.metadata_
|
||||
return self.__pydantic_model__.model_validate(model_dict)
|
||||
|
||||
return self.__pydantic_model__.model_validate(self)
|
||||
|
||||
def to_record(self) -> "BaseModel":
|
||||
|
||||
@@ -72,7 +72,7 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
|
||||
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the agent.")
|
||||
|
||||
description: Optional[str] = Field(None, description="The description of the agent.")
|
||||
metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
|
||||
metadata: Optional[Dict] = Field(None, description="The metadata of the agent.")
|
||||
|
||||
memory: Memory = Field(..., description="The in-context memory of the agent.")
|
||||
tools: List[Tool] = Field(..., description="The tools used by the agent.")
|
||||
@@ -122,7 +122,7 @@ class CreateAgent(BaseModel, validate_assignment=True): #
|
||||
False, description="If true, attaches the Letta multi-agent tools (e.g. sending a message to another agent)."
|
||||
)
|
||||
description: Optional[str] = Field(None, description="The description of the agent.")
|
||||
metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
|
||||
metadata: Optional[Dict] = Field(None, description="The metadata of the agent.")
|
||||
model: Optional[str] = Field(
|
||||
None,
|
||||
description="The LLM configuration handle used by the agent, specified in the format "
|
||||
@@ -203,7 +203,7 @@ class UpdateAgent(BaseModel):
|
||||
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
|
||||
message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
|
||||
description: Optional[str] = Field(None, description="The description of the agent.")
|
||||
metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
|
||||
metadata: Optional[Dict] = Field(None, description="The metadata of the agent.")
|
||||
tool_exec_environment_variables: Optional[Dict[str, str]] = Field(
|
||||
None, description="The environment variables for tool execution specific to this agent."
|
||||
)
|
||||
|
||||
@@ -27,7 +27,7 @@ class BaseBlock(LettaBase, validate_assignment=True):
|
||||
|
||||
# metadata
|
||||
description: Optional[str] = Field(None, description="Description of the block.")
|
||||
metadata_: Optional[dict] = Field({}, description="Metadata of the block.")
|
||||
metadata: Optional[dict] = Field({}, description="Metadata of the block.")
|
||||
|
||||
# def __len__(self):
|
||||
# return len(self.value)
|
||||
@@ -63,7 +63,7 @@ class Block(BaseBlock):
|
||||
label (str): The label of the block (e.g. 'human', 'persona'). This defines a category for the block.
|
||||
template_name (str): The name of the block template (if it is a template).
|
||||
description (str): Description of the block.
|
||||
metadata_ (Dict): Metadata of the block.
|
||||
metadata (Dict): Metadata of the block.
|
||||
user_id (str): The unique identifier of the user associated with the block.
|
||||
"""
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ class JobBase(OrmMetadataBase):
|
||||
__id_prefix__ = "job"
|
||||
status: JobStatus = Field(default=JobStatus.created, description="The status of the job.")
|
||||
completed_at: Optional[datetime] = Field(None, description="The unix timestamp of when the job was completed.")
|
||||
metadata_: Optional[dict] = Field(None, description="The metadata of the job.")
|
||||
metadata: Optional[dict] = Field(None, description="The metadata of the job.")
|
||||
job_type: JobType = Field(default=JobType.JOB, description="The type of the job.")
|
||||
|
||||
|
||||
|
||||
@@ -88,6 +88,12 @@ class LettaBase(BaseModel):
|
||||
return f"{cls.__id_prefix__}-{v}"
|
||||
return v
|
||||
|
||||
def model_dump(self, to_orm: bool = False, **kwargs):
|
||||
data = super().model_dump(**kwargs)
|
||||
if to_orm and "metadata" in data:
|
||||
data["metadata_"] = data.pop("metadata")
|
||||
return data
|
||||
|
||||
|
||||
class OrmMetadataBase(LettaBase):
|
||||
# metadata fields
|
||||
|
||||
@@ -23,7 +23,7 @@ class PassageBase(OrmMetadataBase):
|
||||
|
||||
# file association
|
||||
file_id: Optional[str] = Field(None, description="The unique identifier of the file associated with the passage.")
|
||||
metadata_: Optional[Dict] = Field({}, description="The metadata of the passage.")
|
||||
metadata: Optional[Dict] = Field({}, description="The metadata of the passage.")
|
||||
|
||||
|
||||
class Passage(PassageBase):
|
||||
|
||||
@@ -24,7 +24,7 @@ class Source(BaseSource):
|
||||
name (str): The name of the source.
|
||||
embedding_config (EmbeddingConfig): The embedding configuration used by the source.
|
||||
user_id (str): The ID of the user that created the source.
|
||||
metadata_ (dict): Metadata associated with the source.
|
||||
metadata (dict): Metadata associated with the source.
|
||||
description (str): The description of the source.
|
||||
"""
|
||||
|
||||
@@ -33,7 +33,7 @@ class Source(BaseSource):
|
||||
description: Optional[str] = Field(None, description="The description of the source.")
|
||||
embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the source.")
|
||||
organization_id: Optional[str] = Field(None, description="The ID of the organization that created the source.")
|
||||
metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.")
|
||||
metadata: Optional[dict] = Field(None, description="Metadata associated with the source.")
|
||||
|
||||
# metadata fields
|
||||
created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
|
||||
@@ -54,7 +54,7 @@ class SourceCreate(BaseSource):
|
||||
|
||||
# optional
|
||||
description: Optional[str] = Field(None, description="The description of the source.")
|
||||
metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.")
|
||||
metadata: Optional[dict] = Field(None, description="Metadata associated with the source.")
|
||||
|
||||
|
||||
class SourceUpdate(BaseSource):
|
||||
@@ -64,5 +64,5 @@ class SourceUpdate(BaseSource):
|
||||
|
||||
name: Optional[str] = Field(None, description="The name of the source.")
|
||||
description: Optional[str] = Field(None, description="The description of the source.")
|
||||
metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.")
|
||||
metadata: Optional[dict] = Field(None, description="Metadata associated with the source.")
|
||||
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the source.")
|
||||
|
||||
@@ -556,7 +556,7 @@ async def process_message_background(
|
||||
job_update = JobUpdate(
|
||||
status=JobStatus.completed,
|
||||
completed_at=datetime.utcnow(),
|
||||
metadata_={"result": result.model_dump()}, # Store the result in metadata
|
||||
metadata={"result": result.model_dump()}, # Store the result in metadata
|
||||
)
|
||||
server.job_manager.update_job_by_id(job_id=job_id, job_update=job_update, actor=actor)
|
||||
|
||||
@@ -568,7 +568,7 @@ async def process_message_background(
|
||||
job_update = JobUpdate(
|
||||
status=JobStatus.failed,
|
||||
completed_at=datetime.utcnow(),
|
||||
metadata_={"error": str(e)},
|
||||
metadata={"error": str(e)},
|
||||
)
|
||||
server.job_manager.update_job_by_id(job_id=job_id, job_update=job_update, actor=actor)
|
||||
raise
|
||||
@@ -596,7 +596,7 @@ async def send_message_async(
|
||||
run = Run(
|
||||
user_id=actor.id,
|
||||
status=JobStatus.created,
|
||||
metadata_={
|
||||
metadata={
|
||||
"job_type": "send_message_async",
|
||||
"agent_id": agent_id,
|
||||
},
|
||||
|
||||
@@ -26,9 +26,9 @@ def list_jobs(
|
||||
jobs = server.job_manager.list_jobs(actor=actor)
|
||||
|
||||
if source_id:
|
||||
# can't be in the ORM since we have source_id stored in the metadata_
|
||||
# can't be in the ORM since we have source_id stored in the metadata
|
||||
# TODO: Probably change this
|
||||
jobs = [job for job in jobs if job.metadata_.get("source_id") == source_id]
|
||||
jobs = [job for job in jobs if job.metadata.get("source_id") == source_id]
|
||||
return jobs
|
||||
|
||||
|
||||
|
||||
@@ -131,7 +131,7 @@ def upload_file_to_source(
|
||||
# create job
|
||||
job = Job(
|
||||
user_id=actor.id,
|
||||
metadata_={"type": "embedding", "filename": file.filename, "source_id": source_id},
|
||||
metadata={"type": "embedding", "filename": file.filename, "source_id": source_id},
|
||||
completed_at=None,
|
||||
)
|
||||
job_id = job.id
|
||||
|
||||
@@ -956,8 +956,8 @@ class SyncServer(Server):
|
||||
|
||||
# update job status
|
||||
job.status = JobStatus.completed
|
||||
job.metadata_["num_passages"] = num_passages
|
||||
job.metadata_["num_documents"] = num_documents
|
||||
job.metadata["num_passages"] = num_passages
|
||||
job.metadata["num_documents"] = num_documents
|
||||
self.job_manager.update_job_by_id(job_id=job_id, job_update=JobUpdate(**job.model_dump()), actor=actor)
|
||||
|
||||
# update all agents who have this source attached
|
||||
@@ -1019,7 +1019,7 @@ class SyncServer(Server):
|
||||
attached_agents = [{"id": agent.id, "name": agent.name} for agent in agents]
|
||||
|
||||
# Overwrite metadata field, should be empty anyways
|
||||
source.metadata_ = dict(
|
||||
source.metadata = dict(
|
||||
num_documents=num_documents,
|
||||
num_passages=num_passages,
|
||||
attached_agents=attached_agents,
|
||||
|
||||
@@ -82,7 +82,7 @@ class AgentManager:
|
||||
block_ids = list(agent_create.block_ids or []) # Create a local copy to avoid modifying the original
|
||||
if agent_create.memory_blocks:
|
||||
for create_block in agent_create.memory_blocks:
|
||||
block = self.block_manager.create_or_update_block(PydanticBlock(**create_block.model_dump()), actor=actor)
|
||||
block = self.block_manager.create_or_update_block(PydanticBlock(**create_block.model_dump(to_orm=True)), actor=actor)
|
||||
block_ids.append(block.id)
|
||||
|
||||
# TODO: Remove this block once we deprecate the legacy `tools` field
|
||||
@@ -117,7 +117,7 @@ class AgentManager:
|
||||
source_ids=agent_create.source_ids or [],
|
||||
tags=agent_create.tags or [],
|
||||
description=agent_create.description,
|
||||
metadata_=agent_create.metadata_,
|
||||
metadata=agent_create.metadata,
|
||||
tool_rules=agent_create.tool_rules,
|
||||
actor=actor,
|
||||
)
|
||||
@@ -177,7 +177,7 @@ class AgentManager:
|
||||
source_ids: List[str],
|
||||
tags: List[str],
|
||||
description: Optional[str] = None,
|
||||
metadata_: Optional[Dict] = None,
|
||||
metadata: Optional[Dict] = None,
|
||||
tool_rules: Optional[List[PydanticToolRule]] = None,
|
||||
) -> PydanticAgentState:
|
||||
"""Create a new agent."""
|
||||
@@ -191,7 +191,7 @@ class AgentManager:
|
||||
"embedding_config": embedding_config,
|
||||
"organization_id": actor.organization_id,
|
||||
"description": description,
|
||||
"metadata_": metadata_,
|
||||
"metadata_": metadata,
|
||||
"tool_rules": tool_rules,
|
||||
}
|
||||
|
||||
@@ -242,11 +242,14 @@ class AgentManager:
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# Update scalar fields directly
|
||||
scalar_fields = {"name", "system", "llm_config", "embedding_config", "message_ids", "tool_rules", "description", "metadata_"}
|
||||
scalar_fields = {"name", "system", "llm_config", "embedding_config", "message_ids", "tool_rules", "description", "metadata"}
|
||||
for field in scalar_fields:
|
||||
value = getattr(agent_update, field, None)
|
||||
if value is not None:
|
||||
setattr(agent, field, value)
|
||||
if field == "metadata":
|
||||
setattr(agent, "metadata_", value)
|
||||
else:
|
||||
setattr(agent, field, value)
|
||||
|
||||
# Update relationships using _process_relationship and _process_tags
|
||||
if agent_update.tool_ids is not None:
|
||||
|
||||
@@ -24,11 +24,11 @@ class BlockManager:
|
||||
"""Create a new block based on the Block schema."""
|
||||
db_block = self.get_block_by_id(block.id, actor)
|
||||
if db_block:
|
||||
update_data = BlockUpdate(**block.model_dump(exclude_none=True))
|
||||
update_data = BlockUpdate(**block.model_dump(to_orm=True, exclude_none=True))
|
||||
self.update_block(block.id, update_data, actor)
|
||||
else:
|
||||
with self.session_maker() as session:
|
||||
data = block.model_dump(exclude_none=True)
|
||||
data = block.model_dump(to_orm=True, exclude_none=True)
|
||||
block = BlockModel(**data, organization_id=actor.organization_id)
|
||||
block.create(session, actor=actor)
|
||||
return block.to_pydantic()
|
||||
@@ -40,7 +40,7 @@ class BlockManager:
|
||||
|
||||
with self.session_maker() as session:
|
||||
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
|
||||
update_data = block_update.model_dump(exclude_unset=True, exclude_none=True)
|
||||
update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(block, key, value)
|
||||
|
||||
@@ -39,7 +39,7 @@ class JobManager:
|
||||
with self.session_maker() as session:
|
||||
# Associate the job with the user
|
||||
pydantic_job.user_id = actor.id
|
||||
job_data = pydantic_job.model_dump()
|
||||
job_data = pydantic_job.model_dump(to_orm=True)
|
||||
job = JobModel(**job_data)
|
||||
job.create(session, actor=actor) # Save job in the database
|
||||
return job.to_pydantic()
|
||||
@@ -52,7 +52,7 @@ class JobManager:
|
||||
job = self._verify_job_access(session=session, job_id=job_id, actor=actor, access=["write"])
|
||||
|
||||
# Update job attributes with only the fields that were explicitly set
|
||||
update_data = job_update.model_dump(exclude_unset=True, exclude_none=True)
|
||||
update_data = job_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
|
||||
# Automatically update the completion timestamp if status is set to 'completed'
|
||||
if update_data.get("status") == JobStatus.completed and not job.completed_at:
|
||||
@@ -62,7 +62,9 @@ class JobManager:
|
||||
setattr(job, key, value)
|
||||
|
||||
# Save the updated job to the database
|
||||
return job.update(db_session=session) # TODO: Add this later , actor=actor)
|
||||
job.update(db_session=session) # TODO: Add this later , actor=actor)
|
||||
|
||||
return job.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def get_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob:
|
||||
|
||||
@@ -44,7 +44,7 @@ class OrganizationManager:
|
||||
@enforce_types
|
||||
def _create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
|
||||
with self.session_maker() as session:
|
||||
org = OrganizationModel(**pydantic_org.model_dump())
|
||||
org = OrganizationModel(**pydantic_org.model_dump(to_orm=True))
|
||||
org.create(session)
|
||||
return org.to_pydantic()
|
||||
|
||||
|
||||
@@ -38,14 +38,14 @@ class PassageManager:
|
||||
def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
|
||||
"""Create a new passage in the appropriate table based on whether it has agent_id or source_id."""
|
||||
# Common fields for both passage types
|
||||
data = pydantic_passage.model_dump()
|
||||
data = pydantic_passage.model_dump(to_orm=True)
|
||||
common_fields = {
|
||||
"id": data.get("id"),
|
||||
"text": data["text"],
|
||||
"embedding": data["embedding"],
|
||||
"embedding_config": data["embedding_config"],
|
||||
"organization_id": data["organization_id"],
|
||||
"metadata_": data.get("metadata_", {}),
|
||||
"metadata_": data.get("metadata", {}),
|
||||
"is_deleted": data.get("is_deleted", False),
|
||||
"created_at": data.get("created_at", datetime.utcnow()),
|
||||
}
|
||||
@@ -145,7 +145,7 @@ class PassageManager:
|
||||
raise ValueError(f"Passage with id {passage_id} does not exist.")
|
||||
|
||||
# Update the database record with values from the provided record
|
||||
update_data = passage.model_dump(exclude_unset=True, exclude_none=True)
|
||||
update_data = passage.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(curr_passage, key, value)
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ class ProviderManager:
|
||||
# Lazily create the provider id prior to persistence
|
||||
provider.resolve_identifier()
|
||||
|
||||
new_provider = ProviderModel(**provider.model_dump(exclude_unset=True))
|
||||
new_provider = ProviderModel(**provider.model_dump(to_orm=True, exclude_unset=True))
|
||||
new_provider.create(session)
|
||||
return new_provider.to_pydantic()
|
||||
|
||||
@@ -36,7 +36,7 @@ class ProviderManager:
|
||||
existing_provider = ProviderModel.read(db_session=session, identifier=provider_update.id)
|
||||
|
||||
# Update only the fields that are provided in ProviderUpdate
|
||||
update_data = provider_update.model_dump(exclude_unset=True, exclude_none=True)
|
||||
update_data = provider_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(existing_provider, key, value)
|
||||
|
||||
|
||||
@@ -172,7 +172,7 @@ class SandboxConfigManager:
|
||||
return db_env_var
|
||||
else:
|
||||
with self.session_maker() as session:
|
||||
env_var = SandboxEnvVarModel(**env_var.model_dump(exclude_none=True))
|
||||
env_var = SandboxEnvVarModel(**env_var.model_dump(to_orm=True, exclude_none=True))
|
||||
env_var.create(session, actor=actor)
|
||||
return env_var.to_pydantic()
|
||||
|
||||
@@ -183,7 +183,7 @@ class SandboxConfigManager:
|
||||
"""Update an existing sandbox environment variable."""
|
||||
with self.session_maker() as session:
|
||||
env_var = SandboxEnvVarModel.read(db_session=session, identifier=env_var_id, actor=actor)
|
||||
update_data = env_var_update.model_dump(exclude_unset=True, exclude_none=True)
|
||||
update_data = env_var_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
update_data = {key: value for key, value in update_data.items() if getattr(env_var, key) != value}
|
||||
|
||||
if update_data:
|
||||
|
||||
@@ -30,7 +30,7 @@ class SourceManager:
|
||||
with self.session_maker() as session:
|
||||
# Provide default embedding config if not given
|
||||
source.organization_id = actor.organization_id
|
||||
source = SourceModel(**source.model_dump(exclude_none=True))
|
||||
source = SourceModel(**source.model_dump(to_orm=True, exclude_none=True))
|
||||
source.create(session, actor=actor)
|
||||
return source.to_pydantic()
|
||||
|
||||
@@ -41,7 +41,7 @@ class SourceManager:
|
||||
source = SourceModel.read(db_session=session, identifier=source_id, actor=actor)
|
||||
|
||||
# get update dictionary
|
||||
update_data = source_update.model_dump(exclude_unset=True, exclude_none=True)
|
||||
update_data = source_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
# Remove redundant update fields
|
||||
update_data = {key: value for key, value in update_data.items() if getattr(source, key) != value}
|
||||
|
||||
@@ -132,7 +132,7 @@ class SourceManager:
|
||||
else:
|
||||
with self.session_maker() as session:
|
||||
file_metadata.organization_id = actor.organization_id
|
||||
file_metadata = FileMetadataModel(**file_metadata.model_dump(exclude_none=True))
|
||||
file_metadata = FileMetadataModel(**file_metadata.model_dump(to_orm=True, exclude_none=True))
|
||||
file_metadata.create(session, actor=actor)
|
||||
return file_metadata.to_pydantic()
|
||||
|
||||
|
||||
@@ -364,7 +364,9 @@ class ToolExecutionSandbox:
|
||||
sbx = Sandbox(sandbox_config.get_e2b_config().template, metadata={self.METADATA_CONFIG_STATE_KEY: state_hash})
|
||||
else:
|
||||
# no template
|
||||
sbx = Sandbox(metadata={self.METADATA_CONFIG_STATE_KEY: state_hash}, **e2b_config.model_dump(exclude={"pip_requirements"}))
|
||||
sbx = Sandbox(
|
||||
metadata={self.METADATA_CONFIG_STATE_KEY: state_hash}, **e2b_config.model_dump(to_orm=True, exclude={"pip_requirements"})
|
||||
)
|
||||
|
||||
# install pip requirements
|
||||
if e2b_config.pip_requirements:
|
||||
|
||||
@@ -39,7 +39,7 @@ class ToolManager:
|
||||
tool = self.get_tool_by_name(tool_name=pydantic_tool.name, actor=actor)
|
||||
if tool:
|
||||
# Put to dict and remove fields that should not be reset
|
||||
update_data = pydantic_tool.model_dump(exclude_unset=True, exclude_none=True)
|
||||
update_data = pydantic_tool.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
|
||||
# If there's anything to update
|
||||
if update_data:
|
||||
@@ -67,7 +67,7 @@ class ToolManager:
|
||||
# Auto-generate description if not provided
|
||||
if pydantic_tool.description is None:
|
||||
pydantic_tool.description = pydantic_tool.json_schema.get("description", None)
|
||||
tool_data = pydantic_tool.model_dump()
|
||||
tool_data = pydantic_tool.model_dump(to_orm=True)
|
||||
|
||||
tool = ToolModel(**tool_data)
|
||||
tool.create(session, actor=actor) # Re-raise other database-related errors
|
||||
@@ -112,7 +112,7 @@ class ToolManager:
|
||||
tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor)
|
||||
|
||||
# Update tool attributes with only the fields that were explicitly set
|
||||
update_data = tool_update.model_dump(exclude_none=True)
|
||||
update_data = tool_update.model_dump(to_orm=True, exclude_none=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(tool, key, value)
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ class UserManager:
|
||||
def create_user(self, pydantic_user: PydanticUser) -> PydanticUser:
|
||||
"""Create a new user if it doesn't already exist."""
|
||||
with self.session_maker() as session:
|
||||
new_user = UserModel(**pydantic_user.model_dump())
|
||||
new_user = UserModel(**pydantic_user.model_dump(to_orm=True))
|
||||
new_user.create(session)
|
||||
return new_user.to_pydantic()
|
||||
|
||||
@@ -57,7 +57,7 @@ class UserManager:
|
||||
existing_user = UserModel.read(db_session=session, identifier=user_update.id)
|
||||
|
||||
# Update only the fields that are provided in UserUpdate
|
||||
update_data = user_update.model_dump(exclude_unset=True, exclude_none=True)
|
||||
update_data = user_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(existing_user, key, value)
|
||||
|
||||
|
||||
@@ -10,14 +10,14 @@ from letta.schemas.source import Source
|
||||
def upload_file_using_client(client: Union[LocalClient, RESTClient], source: Source, filename: str) -> Job:
|
||||
# load a file into a source (non-blocking job)
|
||||
upload_job = client.load_file_to_source(filename=filename, source_id=source.id, blocking=False)
|
||||
print("Upload job", upload_job, upload_job.status, upload_job.metadata_)
|
||||
print("Upload job", upload_job, upload_job.status, upload_job.metadata)
|
||||
|
||||
# view active jobs
|
||||
active_jobs = client.list_active_jobs()
|
||||
jobs = client.list_jobs()
|
||||
assert upload_job.id in [j.id for j in jobs]
|
||||
assert len(active_jobs) == 1
|
||||
assert active_jobs[0].metadata_["source_id"] == source.id
|
||||
assert active_jobs[0].metadata["source_id"] == source.id
|
||||
|
||||
# wait for job to finish (with timeout)
|
||||
timeout = 240
|
||||
|
||||
@@ -96,7 +96,7 @@ def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, Up
|
||||
# Assert scalar fields
|
||||
assert agent.system == request.system, f"System prompt mismatch: {agent.system} != {request.system}"
|
||||
assert agent.description == request.description, f"Description mismatch: {agent.description} != {request.description}"
|
||||
assert agent.metadata_ == request.metadata_, f"Metadata mismatch: {agent.metadata_} != {request.metadata_}"
|
||||
assert agent.metadata == request.metadata, f"Metadata mismatch: {agent.metadata} != {request.metadata}"
|
||||
|
||||
# Assert agent env vars
|
||||
if hasattr(request, "tool_exec_environment_variables"):
|
||||
|
||||
@@ -466,8 +466,8 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
assert len(sources) == 1
|
||||
|
||||
# TODO: add back?
|
||||
assert sources[0].metadata_["num_passages"] == 0
|
||||
assert sources[0].metadata_["num_documents"] == 0
|
||||
assert sources[0].metadata["num_passages"] == 0
|
||||
assert sources[0].metadata["num_documents"] == 0
|
||||
|
||||
# update the source
|
||||
original_id = source.id
|
||||
@@ -491,7 +491,7 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
filename = "tests/data/memgpt_paper.pdf"
|
||||
upload_job = upload_file_using_client(client, source, filename)
|
||||
job = client.get_job(upload_job.id)
|
||||
created_passages = job.metadata_["num_passages"]
|
||||
created_passages = job.metadata["num_passages"]
|
||||
|
||||
# TODO: add test for blocking job
|
||||
|
||||
@@ -515,8 +515,8 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
# check number of passages
|
||||
sources = client.list_sources()
|
||||
# TODO: add back?
|
||||
# assert sources.sources[0].metadata_["num_passages"] > 0
|
||||
# assert sources.sources[0].metadata_["num_documents"] == 0 # TODO: fix this once document store added
|
||||
# assert sources.sources[0].metadata["num_passages"] > 0
|
||||
# assert sources.sources[0].metadata["num_documents"] == 0 # TODO: fix this once document store added
|
||||
print(sources)
|
||||
|
||||
# detach the source
|
||||
|
||||
@@ -134,7 +134,7 @@ def default_source(server: SyncServer, default_user):
|
||||
source_pydantic = PydanticSource(
|
||||
name="Test Source",
|
||||
description="This is a test source.",
|
||||
metadata_={"type": "test"},
|
||||
metadata={"type": "test"},
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
)
|
||||
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
|
||||
@@ -146,7 +146,7 @@ def other_source(server: SyncServer, default_user):
|
||||
source_pydantic = PydanticSource(
|
||||
name="Another Test Source",
|
||||
description="This is yet another test source.",
|
||||
metadata_={"type": "another_test"},
|
||||
metadata={"type": "another_test"},
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
)
|
||||
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
|
||||
@@ -235,7 +235,7 @@ def agent_passage_fixture(server: SyncServer, default_user, sarah_agent):
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata_={"type": "test"},
|
||||
metadata={"type": "test"},
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -253,7 +253,7 @@ def source_passage_fixture(server: SyncServer, default_user, default_file, defau
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata_={"type": "test"},
|
||||
metadata={"type": "test"},
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -273,7 +273,7 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata_={"type": "test"},
|
||||
metadata={"type": "test"},
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -291,7 +291,7 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata_={"type": "test"},
|
||||
metadata={"type": "test"},
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -348,7 +348,7 @@ def default_block(server: SyncServer, default_user):
|
||||
value="Default Block Content",
|
||||
description="A default test block",
|
||||
limit=1000,
|
||||
metadata_={"type": "test"},
|
||||
metadata={"type": "test"},
|
||||
)
|
||||
block = block_manager.create_or_update_block(block_data, actor=default_user)
|
||||
yield block
|
||||
@@ -363,7 +363,7 @@ def other_block(server: SyncServer, default_user):
|
||||
value="Other Block Content",
|
||||
description="Another test block",
|
||||
limit=500,
|
||||
metadata_={"type": "test"},
|
||||
metadata={"type": "test"},
|
||||
)
|
||||
block = block_manager.create_or_update_block(block_data, actor=default_user)
|
||||
yield block
|
||||
@@ -444,7 +444,7 @@ def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_too
|
||||
source_ids=[default_source.id],
|
||||
tags=["a", "b"],
|
||||
description="test_description",
|
||||
metadata_={"test_key": "test_value"},
|
||||
metadata={"test_key": "test_value"},
|
||||
tool_rules=[InitToolRule(tool_name=print_tool.name)],
|
||||
initial_message_sequence=[MessageCreate(role=MessageRole.user, text="hello world")],
|
||||
tool_exec_environment_variables={"test_env_var_key_a": "test_env_var_value_a", "test_env_var_key_b": "test_env_var_value_b"},
|
||||
@@ -600,7 +600,7 @@ def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, othe
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(model_name="letta"),
|
||||
message_ids=["10", "20"],
|
||||
metadata_={"train_key": "train_value"},
|
||||
metadata={"train_key": "train_value"},
|
||||
tool_exec_environment_variables={"test_env_var_key_a": "a", "new_tool_exec_key": "n"},
|
||||
)
|
||||
|
||||
@@ -1938,7 +1938,7 @@ def test_create_block(server: SyncServer, default_user):
|
||||
template_name="sample_template",
|
||||
description="A test block",
|
||||
limit=1000,
|
||||
metadata_={"example": "data"},
|
||||
metadata={"example": "data"},
|
||||
)
|
||||
|
||||
block = block_manager.create_or_update_block(block_create, actor=default_user)
|
||||
@@ -1950,7 +1950,7 @@ def test_create_block(server: SyncServer, default_user):
|
||||
assert block.template_name == block_create.template_name
|
||||
assert block.description == block_create.description
|
||||
assert block.limit == block_create.limit
|
||||
assert block.metadata_ == block_create.metadata_
|
||||
assert block.metadata == block_create.metadata
|
||||
assert block.organization_id == default_user.organization_id
|
||||
|
||||
|
||||
@@ -2033,7 +2033,7 @@ def test_create_source(server: SyncServer, default_user):
|
||||
source_pydantic = PydanticSource(
|
||||
name="Test Source",
|
||||
description="This is a test source.",
|
||||
metadata_={"type": "test"},
|
||||
metadata={"type": "test"},
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
)
|
||||
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
|
||||
@@ -2041,7 +2041,7 @@ def test_create_source(server: SyncServer, default_user):
|
||||
# Assertions to check the created source
|
||||
assert source.name == source_pydantic.name
|
||||
assert source.description == source_pydantic.description
|
||||
assert source.metadata_ == source_pydantic.metadata_
|
||||
assert source.metadata == source_pydantic.metadata
|
||||
assert source.organization_id == default_user.organization_id
|
||||
|
||||
|
||||
@@ -2051,14 +2051,14 @@ def test_create_sources_with_same_name_does_not_error(server: SyncServer, defaul
|
||||
source_pydantic = PydanticSource(
|
||||
name=name,
|
||||
description="This is a test source.",
|
||||
metadata_={"type": "medical"},
|
||||
metadata={"type": "medical"},
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
)
|
||||
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
|
||||
source_pydantic = PydanticSource(
|
||||
name=name,
|
||||
description="This is a different test source.",
|
||||
metadata_={"type": "legal"},
|
||||
metadata={"type": "legal"},
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
)
|
||||
same_source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
|
||||
@@ -2073,13 +2073,13 @@ def test_update_source(server: SyncServer, default_user):
|
||||
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
|
||||
|
||||
# Update the source
|
||||
update_data = SourceUpdate(name="Updated Source", description="Updated description", metadata_={"type": "updated"})
|
||||
update_data = SourceUpdate(name="Updated Source", description="Updated description", metadata={"type": "updated"})
|
||||
updated_source = server.source_manager.update_source(source_id=source.id, source_update=update_data, actor=default_user)
|
||||
|
||||
# Assertions to verify update
|
||||
assert updated_source.name == update_data.name
|
||||
assert updated_source.description == update_data.description
|
||||
assert updated_source.metadata_ == update_data.metadata_
|
||||
assert updated_source.metadata == update_data.metadata
|
||||
|
||||
|
||||
def test_delete_source(server: SyncServer, default_user):
|
||||
@@ -2411,7 +2411,7 @@ def test_create_job(server: SyncServer, default_user):
|
||||
"""Test creating a job."""
|
||||
job_data = PydanticJob(
|
||||
status=JobStatus.created,
|
||||
metadata_={"type": "test"},
|
||||
metadata={"type": "test"},
|
||||
)
|
||||
|
||||
created_job = server.job_manager.create_job(job_data, actor=default_user)
|
||||
@@ -2419,7 +2419,7 @@ def test_create_job(server: SyncServer, default_user):
|
||||
# Assertions to ensure the created job matches the expected values
|
||||
assert created_job.user_id == default_user.id
|
||||
assert created_job.status == JobStatus.created
|
||||
assert created_job.metadata_ == {"type": "test"}
|
||||
assert created_job.metadata == {"type": "test"}
|
||||
|
||||
|
||||
def test_get_job_by_id(server: SyncServer, default_user):
|
||||
@@ -2427,7 +2427,7 @@ def test_get_job_by_id(server: SyncServer, default_user):
|
||||
# Create a job
|
||||
job_data = PydanticJob(
|
||||
status=JobStatus.created,
|
||||
metadata_={"type": "test"},
|
||||
metadata={"type": "test"},
|
||||
)
|
||||
created_job = server.job_manager.create_job(job_data, actor=default_user)
|
||||
|
||||
@@ -2437,7 +2437,7 @@ def test_get_job_by_id(server: SyncServer, default_user):
|
||||
# Assertions to ensure the fetched job matches the created job
|
||||
assert fetched_job.id == created_job.id
|
||||
assert fetched_job.status == JobStatus.created
|
||||
assert fetched_job.metadata_ == {"type": "test"}
|
||||
assert fetched_job.metadata == {"type": "test"}
|
||||
|
||||
|
||||
def test_list_jobs(server: SyncServer, default_user):
|
||||
@@ -2446,7 +2446,7 @@ def test_list_jobs(server: SyncServer, default_user):
|
||||
for i in range(3):
|
||||
job_data = PydanticJob(
|
||||
status=JobStatus.created,
|
||||
metadata_={"type": f"test-{i}"},
|
||||
metadata={"type": f"test-{i}"},
|
||||
)
|
||||
server.job_manager.create_job(job_data, actor=default_user)
|
||||
|
||||
@@ -2456,7 +2456,7 @@ def test_list_jobs(server: SyncServer, default_user):
|
||||
# Assertions to check that the created jobs are listed
|
||||
assert len(jobs) == 3
|
||||
assert all(job.user_id == default_user.id for job in jobs)
|
||||
assert all(job.metadata_["type"].startswith("test") for job in jobs)
|
||||
assert all(job.metadata["type"].startswith("test") for job in jobs)
|
||||
|
||||
|
||||
def test_update_job_by_id(server: SyncServer, default_user):
|
||||
@@ -2464,17 +2464,18 @@ def test_update_job_by_id(server: SyncServer, default_user):
|
||||
# Create a job
|
||||
job_data = PydanticJob(
|
||||
status=JobStatus.created,
|
||||
metadata_={"type": "test"},
|
||||
metadata={"type": "test"},
|
||||
)
|
||||
created_job = server.job_manager.create_job(job_data, actor=default_user)
|
||||
assert created_job.metadata == {"type": "test"}
|
||||
|
||||
# Update the job
|
||||
update_data = JobUpdate(status=JobStatus.completed, metadata_={"type": "updated"})
|
||||
update_data = JobUpdate(status=JobStatus.completed, metadata={"type": "updated"})
|
||||
updated_job = server.job_manager.update_job_by_id(created_job.id, update_data, actor=default_user)
|
||||
|
||||
# Assertions to ensure the job was updated
|
||||
assert updated_job.status == JobStatus.completed
|
||||
assert updated_job.metadata_ == {"type": "updated"}
|
||||
assert updated_job.metadata == {"type": "updated"}
|
||||
assert updated_job.completed_at is not None
|
||||
|
||||
|
||||
@@ -2483,7 +2484,7 @@ def test_delete_job_by_id(server: SyncServer, default_user):
|
||||
# Create a job
|
||||
job_data = PydanticJob(
|
||||
status=JobStatus.created,
|
||||
metadata_={"type": "test"},
|
||||
metadata={"type": "test"},
|
||||
)
|
||||
created_job = server.job_manager.create_job(job_data, actor=default_user)
|
||||
|
||||
@@ -2500,7 +2501,7 @@ def test_update_job_auto_complete(server: SyncServer, default_user):
|
||||
# Create a job
|
||||
job_data = PydanticJob(
|
||||
status=JobStatus.created,
|
||||
metadata_={"type": "test"},
|
||||
metadata={"type": "test"},
|
||||
)
|
||||
created_job = server.job_manager.create_job(job_data, actor=default_user)
|
||||
|
||||
@@ -2533,7 +2534,7 @@ def test_list_jobs_pagination(server: SyncServer, default_user):
|
||||
for i in range(10):
|
||||
job_data = PydanticJob(
|
||||
status=JobStatus.created,
|
||||
metadata_={"type": f"test-{i}"},
|
||||
metadata={"type": f"test-{i}"},
|
||||
)
|
||||
server.job_manager.create_job(job_data, actor=default_user)
|
||||
|
||||
@@ -2550,15 +2551,15 @@ def test_list_jobs_by_status(server: SyncServer, default_user):
|
||||
# Create multiple jobs with different statuses
|
||||
job_data_created = PydanticJob(
|
||||
status=JobStatus.created,
|
||||
metadata_={"type": "test-created"},
|
||||
metadata={"type": "test-created"},
|
||||
)
|
||||
job_data_in_progress = PydanticJob(
|
||||
status=JobStatus.running,
|
||||
metadata_={"type": "test-running"},
|
||||
metadata={"type": "test-running"},
|
||||
)
|
||||
job_data_completed = PydanticJob(
|
||||
status=JobStatus.completed,
|
||||
metadata_={"type": "test-completed"},
|
||||
metadata={"type": "test-completed"},
|
||||
)
|
||||
|
||||
server.job_manager.create_job(job_data_created, actor=default_user)
|
||||
@@ -2572,13 +2573,13 @@ def test_list_jobs_by_status(server: SyncServer, default_user):
|
||||
|
||||
# Assertions
|
||||
assert len(created_jobs) == 1
|
||||
assert created_jobs[0].metadata_["type"] == job_data_created.metadata_["type"]
|
||||
assert created_jobs[0].metadata["type"] == job_data_created.metadata["type"]
|
||||
|
||||
assert len(in_progress_jobs) == 1
|
||||
assert in_progress_jobs[0].metadata_["type"] == job_data_in_progress.metadata_["type"]
|
||||
assert in_progress_jobs[0].metadata["type"] == job_data_in_progress.metadata["type"]
|
||||
|
||||
assert len(completed_jobs) == 1
|
||||
assert completed_jobs[0].metadata_["type"] == job_data_completed.metadata_["type"]
|
||||
assert completed_jobs[0].metadata["type"] == job_data_completed.metadata["type"]
|
||||
|
||||
|
||||
def test_list_jobs_filter_by_type(server: SyncServer, default_user, default_job):
|
||||
|
||||
@@ -930,6 +930,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
assert source.created_by_id == user_id
|
||||
|
||||
# Create a test file with some content
|
||||
test_file = tmp_path / "test.txt"
|
||||
@@ -943,7 +944,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
job = server.job_manager.create_job(
|
||||
PydanticJob(
|
||||
user_id=user_id,
|
||||
metadata_={"type": "embedding", "filename": test_file.name, "source_id": source.id},
|
||||
metadata={"type": "embedding", "filename": test_file.name, "source_id": source.id},
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
@@ -959,8 +960,8 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
# Verify job completed successfully
|
||||
job = server.job_manager.get_job_by_id(job_id=job.id, actor=actor)
|
||||
assert job.status == "completed"
|
||||
assert job.metadata_["num_passages"] == 1
|
||||
assert job.metadata_["num_documents"] == 1
|
||||
assert job.metadata["num_passages"] == 1
|
||||
assert job.metadata["num_documents"] == 1
|
||||
|
||||
# Verify passages were added
|
||||
first_file_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
|
||||
@@ -974,7 +975,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
job2 = server.job_manager.create_job(
|
||||
PydanticJob(
|
||||
user_id=user_id,
|
||||
metadata_={"type": "embedding", "filename": test_file2.name, "source_id": source.id},
|
||||
metadata={"type": "embedding", "filename": test_file2.name, "source_id": source.id},
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
@@ -990,8 +991,8 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
# Verify second job completed successfully
|
||||
job2 = server.job_manager.get_job_by_id(job_id=job2.id, actor=actor)
|
||||
assert job2.status == "completed"
|
||||
assert job2.metadata_["num_passages"] >= 10
|
||||
assert job2.metadata_["num_documents"] == 1
|
||||
assert job2.metadata["num_passages"] >= 10
|
||||
assert job2.metadata["num_documents"] == 1
|
||||
|
||||
# Verify passages were appended (not replaced)
|
||||
final_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
|
||||
|
||||
Reference in New Issue
Block a user