fix: Fix small benchmark bugs (#1826)
Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
This commit is contained in:
@@ -2,11 +2,11 @@
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Union
|
||||
|
||||
import typer
|
||||
|
||||
from letta import create_client
|
||||
from letta import LocalClient, RESTClient, create_client
|
||||
from letta.benchmark.constants import HUMAN, PERSONA, PROMPTS, TRIES
|
||||
from letta.config import LettaConfig
|
||||
|
||||
@@ -17,11 +17,13 @@ from letta.utils import get_human_text, get_persona_text
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
def send_message(message: str, agent_id, turn: int, fn_type: str, print_msg: bool = False, n_tries: int = TRIES):
|
||||
def send_message(
|
||||
client: Union[LocalClient, RESTClient], message: str, agent_id, turn: int, fn_type: str, print_msg: bool = False, n_tries: int = TRIES
|
||||
):
|
||||
try:
|
||||
print_msg = f"\t-> Now running {fn_type}. Progress: {turn}/{n_tries}"
|
||||
print(print_msg, end="\r", flush=True)
|
||||
response = client.user_message(agent_id=agent_id, message=message, return_token_count=True)
|
||||
response = client.user_message(agent_id=agent_id, message=message)
|
||||
|
||||
if turn + 1 == n_tries:
|
||||
print(" " * len(print_msg), end="\r", flush=True)
|
||||
@@ -65,7 +67,7 @@ def bench(
|
||||
|
||||
agent_id = agent.id
|
||||
result, msg = send_message(
|
||||
message=message, agent_id=agent_id, turn=i, fn_type=fn_type, print_msg=print_messages, n_tries=n_tries
|
||||
client=client, message=message, agent_id=agent_id, turn=i, fn_type=fn_type, print_msg=print_messages, n_tries=n_tries
|
||||
)
|
||||
|
||||
if print_messages:
|
||||
|
||||
Reference in New Issue
Block a user