chore: cleanup (#2480)
This commit is contained in:
@@ -1,6 +0,0 @@
|
||||
{
|
||||
"context_window": 128000,
|
||||
"model": "gpt-4o-mini",
|
||||
"model_endpoint_type": "azure",
|
||||
"model_wrapper": null
|
||||
}
|
||||
@@ -98,7 +98,7 @@ class VoiceSleeptimeAgent(LettaAgent):
|
||||
# Special memory case
|
||||
target_tool = next((x for x in agent_state.tools if x.name == tool_name), None)
|
||||
if not target_tool:
|
||||
return ToolExecutionResult(func_return=f"Tool not found: {tool_name}", success_flag=False)
|
||||
return ToolExecutionResult(status="error", func_return=f"Tool not found: {tool_name}")
|
||||
|
||||
try:
|
||||
if target_tool.name == "rethink_user_memory" and target_tool.tool_type == ToolType.LETTA_VOICE_SLEEPTIME_CORE:
|
||||
|
||||
@@ -144,7 +144,7 @@ class DirectoryConnector(DataConnector):
|
||||
self.recursive = recursive
|
||||
self.extensions = extensions
|
||||
|
||||
if self.recursive == True:
|
||||
if self.recursive:
|
||||
assert self.input_directory is not None, "Must provide input directory if recursive is True."
|
||||
|
||||
def find_files(self, source: Source) -> Iterator[FileMetadata]:
|
||||
|
||||
@@ -133,11 +133,8 @@ def get_function_name_and_docstring(source_code: str, name: Optional[str] = None
|
||||
|
||||
return function_name, docstring
|
||||
|
||||
except Exception as e:
|
||||
raise LettaToolCreateError(f"Failed to parse function name and docstring: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise LettaToolCreateError(f"Name and docstring generation failed: {str(e)}")
|
||||
raise LettaToolCreateError(f"Failed to parse function name and docstring: {str(e)}")
|
||||
|
||||
@@ -107,7 +107,7 @@ def stringify_message(message: Message, use_assistant_name: bool = False) -> str
|
||||
elif message.role == "tool":
|
||||
if message.content:
|
||||
content = json.loads(message.content[0].text)
|
||||
if content["message"] != "None" and content["message"] != None:
|
||||
if content["message"] is not "None" and content["message"] is not None:
|
||||
return f"{assistant_name}: Tool call returned {content['message']}"
|
||||
return None
|
||||
elif message.role == "system":
|
||||
|
||||
@@ -16,11 +16,6 @@ def datetime_to_timestamp(dt):
|
||||
return int(dt.timestamp())
|
||||
|
||||
|
||||
def timestamp_to_datetime(ts):
|
||||
# convert integer timestamp to datetime object
|
||||
return datetime.fromtimestamp(ts)
|
||||
|
||||
|
||||
def get_local_time_military():
|
||||
# Get the current time in UTC
|
||||
current_time_utc = datetime.now(pytz.utc)
|
||||
|
||||
@@ -141,7 +141,8 @@ class ToolRulesSolver(BaseModel):
|
||||
"""Check if the tool is defined as a continue tool in the tool rules."""
|
||||
return any(rule.tool_name == tool_name for rule in self.continue_tool_rules)
|
||||
|
||||
def validate_conditional_tool(self, rule: ConditionalToolRule):
|
||||
@staticmethod
|
||||
def validate_conditional_tool(rule: ConditionalToolRule):
|
||||
"""
|
||||
Validate a conditional tool rule
|
||||
|
||||
|
||||
@@ -69,8 +69,6 @@ class CommonSqlalchemyMetaMixins(Base):
|
||||
"""returns the user id for the specified property"""
|
||||
full_prop = f"_{prop}_by_id"
|
||||
prop_value = getattr(self, full_prop, None)
|
||||
if not prop_value:
|
||||
return
|
||||
return prop_value
|
||||
|
||||
def _user_id_setter(self, prop: str, value: str) -> None:
|
||||
|
||||
@@ -47,12 +47,6 @@ class OpenAIMessage(BaseModel):
|
||||
metadata: Optional[Dict] = Field(None, description="Metadata associated with the message.")
|
||||
|
||||
|
||||
class MessageFile(BaseModel):
|
||||
id: str
|
||||
object: str = "thread.message.file"
|
||||
created_at: int # unix timestamp
|
||||
|
||||
|
||||
class OpenAIThread(BaseModel):
|
||||
"""Represents an OpenAI thread (equivalent to Letta agent)"""
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ class Provider(ProviderBase):
|
||||
return f"{base_name}/{model_name}"
|
||||
|
||||
def cast_to_subtype(self):
|
||||
match (self.provider_type):
|
||||
match self.provider_type:
|
||||
case ProviderType.letta:
|
||||
return LettaProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.openai:
|
||||
|
||||
@@ -1563,45 +1563,6 @@ class AgentManager:
|
||||
)
|
||||
return await self.append_to_in_context_messages_async([system_message], agent_id=agent_state.id, actor=actor)
|
||||
|
||||
# TODO: I moved this from agent.py - replace all mentions of this with the agent_manager version
|
||||
@trace_method
|
||||
@enforce_types
|
||||
def update_memory_if_changed(self, agent_id: str, new_memory: Memory, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Update internal memory object and system prompt if there have been modifications.
|
||||
|
||||
Args:
|
||||
actor:
|
||||
agent_id:
|
||||
new_memory (Memory): the new memory object to compare to the current memory object
|
||||
|
||||
Returns:
|
||||
modified (bool): whether the memory was updated
|
||||
"""
|
||||
agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
system_message = self.message_manager.get_message_by_id(message_id=agent_state.message_ids[0], actor=actor)
|
||||
if new_memory.compile() not in system_message.content[0].text:
|
||||
# update the blocks (LRW) in the DB
|
||||
for label in agent_state.memory.list_block_labels():
|
||||
updated_value = new_memory.get_block(label).value
|
||||
if updated_value != agent_state.memory.get_block(label).value:
|
||||
# update the block if it's changed
|
||||
block_id = agent_state.memory.get_block(label).id
|
||||
self.block_manager.update_block(block_id=block_id, block_update=BlockUpdate(value=updated_value), actor=actor)
|
||||
|
||||
# refresh memory from DB (using block ids)
|
||||
agent_state.memory = Memory(
|
||||
blocks=[self.block_manager.get_block_by_id(block.id, actor=actor) for block in agent_state.memory.get_blocks()],
|
||||
prompt_template=get_prompt_template_for_agent_type(agent_state.agent_type),
|
||||
)
|
||||
|
||||
# NOTE: don't do this since re-buildin the memory is handled at the start of the step
|
||||
# rebuild memory - this records the last edited timestamp of the memory
|
||||
# TODO: pass in update timestamp from block edit time
|
||||
agent_state = self.rebuild_system_prompt(agent_id=agent_id, actor=actor)
|
||||
|
||||
return agent_state
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
async def update_memory_if_changed_async(self, agent_id: str, new_memory: Memory, actor: PydanticUser) -> PydanticAgentState:
|
||||
@@ -1659,51 +1620,6 @@ class AgentManager:
|
||||
# ======================================================================================================================
|
||||
# Source Management
|
||||
# ======================================================================================================================
|
||||
@trace_method
|
||||
@enforce_types
|
||||
def attach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Attaches a source to an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to attach the source to
|
||||
source_id: ID of the source to attach
|
||||
actor: User performing the action
|
||||
|
||||
Raises:
|
||||
ValueError: If either agent or source doesn't exist
|
||||
IntegrityError: If the source is already attached to the agent
|
||||
"""
|
||||
|
||||
with db_registry.session() as session:
|
||||
# Verify both agent and source exist and user has permission to access them
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# The _process_relationship helper already handles duplicate checking via unique constraint
|
||||
_process_relationship(
|
||||
session=session,
|
||||
agent=agent,
|
||||
relationship_name="sources",
|
||||
model_class=SourceModel,
|
||||
item_ids=[source_id],
|
||||
allow_partial=False,
|
||||
replace=False, # Extend existing sources rather than replace
|
||||
)
|
||||
|
||||
# Commit the changes
|
||||
agent.update(session, actor=actor)
|
||||
|
||||
# Force rebuild of system prompt so that the agent is updated with passage count
|
||||
# and recent passages and add system message alert to agent
|
||||
self.rebuild_system_prompt(agent_id=agent_id, actor=actor, force=True)
|
||||
self.append_system_message(
|
||||
agent_id=agent_id,
|
||||
content=DATA_SOURCE_ATTACH_ALERT,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
return agent.to_pydantic()
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
async def attach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
|
||||
@@ -569,29 +569,6 @@ class MessageManager:
|
||||
results = result.scalars().all()
|
||||
return [msg.to_pydantic() for msg in results]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def delete_all_messages_for_agent(self, agent_id: str, actor: PydanticUser) -> int:
|
||||
"""
|
||||
Efficiently deletes all messages associated with a given agent_id,
|
||||
while enforcing permission checks and avoiding any ORM‑level loads.
|
||||
"""
|
||||
with db_registry.session() as session:
|
||||
# 1) verify the agent exists and the actor has access
|
||||
AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# 2) issue a CORE DELETE against the mapped class
|
||||
stmt = (
|
||||
delete(MessageModel).where(MessageModel.agent_id == agent_id).where(MessageModel.organization_id == actor.organization_id)
|
||||
)
|
||||
result = session.execute(stmt)
|
||||
|
||||
# 3) commit once
|
||||
session.commit()
|
||||
|
||||
# 4) return the number of rows deleted
|
||||
return result.rowcount
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def delete_all_messages_for_agent_async(self, agent_id: str, actor: PydanticUser) -> int:
|
||||
|
||||
@@ -207,7 +207,6 @@ if __name__ == "__main__":
|
||||
|
||||
# skip if exists
|
||||
model_formatted = args.model.replace("/", "-")
|
||||
model_formatted = args.model.replace("/", "-")
|
||||
baseline_formatted = args.baseline.replace("/", "-")
|
||||
filename = f"results/nested_kv/nested_kv_results_{baseline_formatted}_nesting_{args.nesting_levels}_model_{model_formatted}_seed_{args.seed}.json"
|
||||
if not args.rerun and os.path.exists(filename):
|
||||
|
||||
@@ -189,7 +189,7 @@ def test_web_search(
|
||||
) -> None:
|
||||
user_message = MessageCreate(
|
||||
role="user",
|
||||
content=("Use the web search tool to find the latest news about San Francisco."),
|
||||
content="Use the web search tool to find the latest news about San Francisco.",
|
||||
otid=USER_MESSAGE_OTID,
|
||||
)
|
||||
|
||||
|
||||
@@ -513,7 +513,7 @@ async def agent_passages_setup(server, default_source, default_user, sarah_agent
|
||||
agent_id = sarah_agent.id
|
||||
actor = default_user
|
||||
|
||||
server.agent_manager.attach_source(agent_id=agent_id, source_id=default_source.id, actor=actor)
|
||||
await server.agent_manager.attach_source_async(agent_id=agent_id, source_id=default_source.id, actor=actor)
|
||||
|
||||
# Create some source passages
|
||||
source_passages = []
|
||||
@@ -3551,7 +3551,7 @@ async def test_create_and_upsert_identity(server: SyncServer, default_user, even
|
||||
assert identity.identity_type == identity_create.identity_type
|
||||
assert identity.properties == identity_create.properties
|
||||
assert identity.agent_ids == []
|
||||
assert identity.project_id == None
|
||||
assert identity.project_id is None
|
||||
|
||||
with pytest.raises(UniqueConstraintViolationError):
|
||||
await server.identity_manager.create_identity_async(
|
||||
@@ -4896,7 +4896,7 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_
|
||||
assert msg.tool_call.name == "custom_tool"
|
||||
|
||||
|
||||
def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_agent):
|
||||
def test_get_run_messages_with_assistant_message(server: SyncServer, default_user: PydanticUser, sarah_agent):
|
||||
"""Test getting messages for a run with request config."""
|
||||
# Create a run with custom request config
|
||||
run = server.job_manager.create_job(
|
||||
|
||||
@@ -239,7 +239,7 @@ async def test_round_robin(server, actor, participant_agents):
|
||||
assert group.manager_type == ManagerType.round_robin
|
||||
assert group.description == description
|
||||
assert group.agent_ids == [agent.id for agent in participant_agents]
|
||||
assert group.max_turns == None
|
||||
assert group.max_turns is None
|
||||
assert group.manager_agent_id is None
|
||||
assert group.termination_token is None
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ def test_parse_number_cases(strict_parser):
|
||||
def test_parse_boolean_true(strict_parser):
|
||||
assert strict_parser.parse("true") is True, "Should parse 'true'."
|
||||
# Check leftover
|
||||
assert strict_parser.last_parse_reminding == None, "No extra tokens expected."
|
||||
assert strict_parser.last_parse_reminding is None, "No extra tokens expected."
|
||||
|
||||
|
||||
def test_parse_boolean_false(strict_parser):
|
||||
|
||||
@@ -1143,7 +1143,7 @@ async def test_messages_with_provider_override(server: SyncServer, user_id: str,
|
||||
for step_id in step_ids:
|
||||
step = await server.step_manager.get_step_async(step_id=step_id, actor=actor)
|
||||
assert step, "Step was not logged correctly"
|
||||
assert step.provider_id == None
|
||||
assert step.provider_id is None
|
||||
assert step.provider_name == agent.llm_config.model_endpoint_type
|
||||
assert step.model == agent.llm_config.model
|
||||
assert step.context_window_limit == agent.llm_config.context_window
|
||||
|
||||
Reference in New Issue
Block a user