fix: Fix small benchmark bugs (#1826)

Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
This commit is contained in:
Matthew Zhou
2024-10-03 12:08:10 -07:00
committed by GitHub
parent 58aa7e09df
commit 17192d5aa7

View File

@@ -2,11 +2,11 @@
import time
import uuid
from typing import Annotated
from typing import Annotated, Union
import typer
from letta import create_client
from letta import LocalClient, RESTClient, create_client
from letta.benchmark.constants import HUMAN, PERSONA, PROMPTS, TRIES
from letta.config import LettaConfig
@@ -17,11 +17,13 @@ from letta.utils import get_human_text, get_persona_text
app = typer.Typer()
def send_message(message: str, agent_id, turn: int, fn_type: str, print_msg: bool = False, n_tries: int = TRIES):
def send_message(
client: Union[LocalClient, RESTClient], message: str, agent_id, turn: int, fn_type: str, print_msg: bool = False, n_tries: int = TRIES
):
try:
print_msg = f"\t-> Now running {fn_type}. Progress: {turn}/{n_tries}"
print(print_msg, end="\r", flush=True)
response = client.user_message(agent_id=agent_id, message=message, return_token_count=True)
response = client.user_message(agent_id=agent_id, message=message)
if turn + 1 == n_tries:
print(" " * len(print_msg), end="\r", flush=True)
@@ -65,7 +67,7 @@ def bench(
agent_id = agent.id
result, msg = send_message(
message=message, agent_id=agent_id, turn=i, fn_type=fn_type, print_msg=print_messages, n_tries=n_tries
client=client, message=message, agent_id=agent_id, turn=i, fn_type=fn_type, print_msg=print_messages, n_tries=n_tries
)
if print_messages: