chore: cleanup (#2480)

This commit is contained in:
Andy Li
2025-05-29 10:40:41 -07:00
committed by GitHub
parent c54447a985
commit 831b7d5862
18 changed files with 14 additions and 143 deletions

View File

@@ -1,6 +0,0 @@
{
"context_window": 128000,
"model": "gpt-4o-mini",
"model_endpoint_type": "azure",
"model_wrapper": null
}

View File

@@ -98,7 +98,7 @@ class VoiceSleeptimeAgent(LettaAgent):
# Special memory case
target_tool = next((x for x in agent_state.tools if x.name == tool_name), None)
if not target_tool:
return ToolExecutionResult(func_return=f"Tool not found: {tool_name}", success_flag=False)
return ToolExecutionResult(status="error", func_return=f"Tool not found: {tool_name}")
try:
if target_tool.name == "rethink_user_memory" and target_tool.tool_type == ToolType.LETTA_VOICE_SLEEPTIME_CORE:

View File

@@ -144,7 +144,7 @@ class DirectoryConnector(DataConnector):
self.recursive = recursive
self.extensions = extensions
if self.recursive == True:
if self.recursive:
assert self.input_directory is not None, "Must provide input directory if recursive is True."
def find_files(self, source: Source) -> Iterator[FileMetadata]:

View File

@@ -133,11 +133,8 @@ def get_function_name_and_docstring(source_code: str, name: Optional[str] = None
return function_name, docstring
except Exception as e:
raise LettaToolCreateError(f"Failed to parse function name and docstring: {str(e)}")
except Exception as e:
import traceback
traceback.print_exc()
raise LettaToolCreateError(f"Name and docstring generation failed: {str(e)}")
raise LettaToolCreateError(f"Failed to parse function name and docstring: {str(e)}")

View File

@@ -107,7 +107,7 @@ def stringify_message(message: Message, use_assistant_name: bool = False) -> str
elif message.role == "tool":
if message.content:
content = json.loads(message.content[0].text)
if content["message"] != "None" and content["message"] != None:
if content["message"] is not "None" and content["message"] is not None:
return f"{assistant_name}: Tool call returned {content['message']}"
return None
elif message.role == "system":

View File

@@ -16,11 +16,6 @@ def datetime_to_timestamp(dt):
return int(dt.timestamp())
def timestamp_to_datetime(ts):
# convert integer timestamp to datetime object
return datetime.fromtimestamp(ts)
def get_local_time_military():
# Get the current time in UTC
current_time_utc = datetime.now(pytz.utc)

View File

@@ -141,7 +141,8 @@ class ToolRulesSolver(BaseModel):
"""Check if the tool is defined as a continue tool in the tool rules."""
return any(rule.tool_name == tool_name for rule in self.continue_tool_rules)
def validate_conditional_tool(self, rule: ConditionalToolRule):
@staticmethod
def validate_conditional_tool(rule: ConditionalToolRule):
"""
Validate a conditional tool rule

View File

@@ -69,8 +69,6 @@ class CommonSqlalchemyMetaMixins(Base):
"""returns the user id for the specified property"""
full_prop = f"_{prop}_by_id"
prop_value = getattr(self, full_prop, None)
if not prop_value:
return
return prop_value
def _user_id_setter(self, prop: str, value: str) -> None:

View File

@@ -47,12 +47,6 @@ class OpenAIMessage(BaseModel):
metadata: Optional[Dict] = Field(None, description="Metadata associated with the message.")
class MessageFile(BaseModel):
id: str
object: str = "thread.message.file"
created_at: int # unix timestamp
class OpenAIThread(BaseModel):
"""Represents an OpenAI thread (equivalent to Letta agent)"""

View File

@@ -86,7 +86,7 @@ class Provider(ProviderBase):
return f"{base_name}/{model_name}"
def cast_to_subtype(self):
match (self.provider_type):
match self.provider_type:
case ProviderType.letta:
return LettaProvider(**self.model_dump(exclude_none=True))
case ProviderType.openai:

View File

@@ -1563,45 +1563,6 @@ class AgentManager:
)
return await self.append_to_in_context_messages_async([system_message], agent_id=agent_state.id, actor=actor)
# TODO: I moved this from agent.py - replace all mentions of this with the agent_manager version
@trace_method
@enforce_types
def update_memory_if_changed(self, agent_id: str, new_memory: Memory, actor: PydanticUser) -> PydanticAgentState:
"""
Update internal memory object and system prompt if there have been modifications.
Args:
actor:
agent_id:
new_memory (Memory): the new memory object to compare to the current memory object
Returns:
modified (bool): whether the memory was updated
"""
agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)
system_message = self.message_manager.get_message_by_id(message_id=agent_state.message_ids[0], actor=actor)
if new_memory.compile() not in system_message.content[0].text:
# update the blocks (LRW) in the DB
for label in agent_state.memory.list_block_labels():
updated_value = new_memory.get_block(label).value
if updated_value != agent_state.memory.get_block(label).value:
# update the block if it's changed
block_id = agent_state.memory.get_block(label).id
self.block_manager.update_block(block_id=block_id, block_update=BlockUpdate(value=updated_value), actor=actor)
# refresh memory from DB (using block ids)
agent_state.memory = Memory(
blocks=[self.block_manager.get_block_by_id(block.id, actor=actor) for block in agent_state.memory.get_blocks()],
prompt_template=get_prompt_template_for_agent_type(agent_state.agent_type),
)
# NOTE: don't do this since re-buildin the memory is handled at the start of the step
# rebuild memory - this records the last edited timestamp of the memory
# TODO: pass in update timestamp from block edit time
agent_state = self.rebuild_system_prompt(agent_id=agent_id, actor=actor)
return agent_state
@trace_method
@enforce_types
async def update_memory_if_changed_async(self, agent_id: str, new_memory: Memory, actor: PydanticUser) -> PydanticAgentState:
@@ -1659,51 +1620,6 @@ class AgentManager:
# ======================================================================================================================
# Source Management
# ======================================================================================================================
@trace_method
@enforce_types
def attach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
"""
Attaches a source to an agent.
Args:
agent_id: ID of the agent to attach the source to
source_id: ID of the source to attach
actor: User performing the action
Raises:
ValueError: If either agent or source doesn't exist
IntegrityError: If the source is already attached to the agent
"""
with db_registry.session() as session:
# Verify both agent and source exist and user has permission to access them
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
# The _process_relationship helper already handles duplicate checking via unique constraint
_process_relationship(
session=session,
agent=agent,
relationship_name="sources",
model_class=SourceModel,
item_ids=[source_id],
allow_partial=False,
replace=False, # Extend existing sources rather than replace
)
# Commit the changes
agent.update(session, actor=actor)
# Force rebuild of system prompt so that the agent is updated with passage count
# and recent passages and add system message alert to agent
self.rebuild_system_prompt(agent_id=agent_id, actor=actor, force=True)
self.append_system_message(
agent_id=agent_id,
content=DATA_SOURCE_ATTACH_ALERT,
actor=actor,
)
return agent.to_pydantic()
@trace_method
@enforce_types
async def attach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:

View File

@@ -569,29 +569,6 @@ class MessageManager:
results = result.scalars().all()
return [msg.to_pydantic() for msg in results]
@enforce_types
@trace_method
def delete_all_messages_for_agent(self, agent_id: str, actor: PydanticUser) -> int:
"""
Efficiently deletes all messages associated with a given agent_id,
while enforcing permission checks and avoiding any ORMlevel loads.
"""
with db_registry.session() as session:
# 1) verify the agent exists and the actor has access
AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
# 2) issue a CORE DELETE against the mapped class
stmt = (
delete(MessageModel).where(MessageModel.agent_id == agent_id).where(MessageModel.organization_id == actor.organization_id)
)
result = session.execute(stmt)
# 3) commit once
session.commit()
# 4) return the number of rows deleted
return result.rowcount
@enforce_types
@trace_method
async def delete_all_messages_for_agent_async(self, agent_id: str, actor: PydanticUser) -> int:

View File

@@ -207,7 +207,6 @@ if __name__ == "__main__":
# skip if exists
model_formatted = args.model.replace("/", "-")
model_formatted = args.model.replace("/", "-")
baseline_formatted = args.baseline.replace("/", "-")
filename = f"results/nested_kv/nested_kv_results_{baseline_formatted}_nesting_{args.nesting_levels}_model_{model_formatted}_seed_{args.seed}.json"
if not args.rerun and os.path.exists(filename):

View File

@@ -189,7 +189,7 @@ def test_web_search(
) -> None:
user_message = MessageCreate(
role="user",
content=("Use the web search tool to find the latest news about San Francisco."),
content="Use the web search tool to find the latest news about San Francisco.",
otid=USER_MESSAGE_OTID,
)

View File

@@ -513,7 +513,7 @@ async def agent_passages_setup(server, default_source, default_user, sarah_agent
agent_id = sarah_agent.id
actor = default_user
server.agent_manager.attach_source(agent_id=agent_id, source_id=default_source.id, actor=actor)
await server.agent_manager.attach_source_async(agent_id=agent_id, source_id=default_source.id, actor=actor)
# Create some source passages
source_passages = []
@@ -3551,7 +3551,7 @@ async def test_create_and_upsert_identity(server: SyncServer, default_user, even
assert identity.identity_type == identity_create.identity_type
assert identity.properties == identity_create.properties
assert identity.agent_ids == []
assert identity.project_id == None
assert identity.project_id is None
with pytest.raises(UniqueConstraintViolationError):
await server.identity_manager.create_identity_async(
@@ -4896,7 +4896,7 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_
assert msg.tool_call.name == "custom_tool"
def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_agent):
def test_get_run_messages_with_assistant_message(server: SyncServer, default_user: PydanticUser, sarah_agent):
"""Test getting messages for a run with request config."""
# Create a run with custom request config
run = server.job_manager.create_job(

View File

@@ -239,7 +239,7 @@ async def test_round_robin(server, actor, participant_agents):
assert group.manager_type == ManagerType.round_robin
assert group.description == description
assert group.agent_ids == [agent.id for agent in participant_agents]
assert group.max_turns == None
assert group.max_turns is None
assert group.manager_agent_id is None
assert group.termination_token is None

View File

@@ -96,7 +96,7 @@ def test_parse_number_cases(strict_parser):
def test_parse_boolean_true(strict_parser):
assert strict_parser.parse("true") is True, "Should parse 'true'."
# Check leftover
assert strict_parser.last_parse_reminding == None, "No extra tokens expected."
assert strict_parser.last_parse_reminding is None, "No extra tokens expected."
def test_parse_boolean_false(strict_parser):

View File

@@ -1143,7 +1143,7 @@ async def test_messages_with_provider_override(server: SyncServer, user_id: str,
for step_id in step_ids:
step = await server.step_manager.get_step_async(step_id=step_id, actor=actor)
assert step, "Step was not logged correctly"
assert step.provider_id == None
assert step.provider_id is None
assert step.provider_name == agent.llm_config.model_endpoint_type
assert step.model == agent.llm_config.model
assert step.context_window_limit == agent.llm_config.context_window