fix: Fix memgpt benchmark command (#1041)
@@ -19,22 +19,22 @@ def send_message(message: str, agent_id, turn: int, fn_type: str, print_msg: boo
     try:
         print_msg = f"\t-> Now running {fn_type}. Progress: {turn}/{n_tries}"
         print(print_msg, end="\r", flush=True)
-        response, tokens_accumulated = client.user_message(agent_id=agent_id, message=message, return_token_count=True)
+        response = client.user_message(agent_id=agent_id, message=message)

         if turn + 1 == n_tries:
             print(" " * len(print_msg), end="\r", flush=True)

         for r in response:
             if "function_call" in r and fn_type in r["function_call"] and any("assistant_message" in re for re in response):
-                return True, r["function_call"], tokens_accumulated
+                return True, r["function_call"]

-        return False, "No function called.", tokens_accumulated
+        return False, "No function called."
     except LLMJSONParsingError as e:
         print(f"Error in parsing MemGPT JSON: {e}")
-        return False, "Failed to decode valid MemGPT JSON from LLM output.", tokens_accumulated
+        return False, "Failed to decode valid MemGPT JSON from LLM output."
     except Exception as e:
         print(f"An unexpected error occurred: {e}")
-        return False, "An unexpected error occurred.", tokens_accumulated
+        return False, "An unexpected error occurred."


 @app.command()
@@ -55,10 +55,10 @@ def bench(
     bench_id = uuid.uuid4()

     for i in range(n_tries):
-        agent = client.create_agent(agent_config={"name": f"benchmark_{bench_id}_agent_{i}", "persona": PERSONA, "human": HUMAN})
+        agent = client.create_agent(name=f"benchmark_{bench_id}_agent_{i}", persona=PERSONA, human=HUMAN)

         agent_id = agent.id
-        result, msg, tokens_accumulated = send_message(
+        result, msg = send_message(
             message=message, agent_id=agent_id, turn=i, fn_type=fn_type, print_msg=print_messages, n_tries=n_tries
         )

@@ -68,7 +68,8 @@ def bench(
         if result:
             score += 1

-        total_tokens_accumulated += tokens_accumulated
+        # TODO: add back once we start tracking usage via the client
+        # total_tokens_accumulated += tokens_accumulated

     elapsed_time_run = round(time.time() - start_time_run, 2)
     print(f"Score for {fn_type}: {score}/{n_tries}, took {elapsed_time_run} seconds")
@@ -84,5 +85,6 @@ def bench(
     print(f"HUMAN: {config.human}")

     print(
-        f"\n\t-> Total score: {total_score}/{len(PROMPTS) * n_tries}, took {elapsed_time} seconds at average of {round(total_tokens_accumulated/elapsed_time, 2)} t/s\n"
+        # f"\n\t-> Total score: {total_score}/{len(PROMPTS) * n_tries}, took {elapsed_time} seconds at average of {round(total_tokens_accumulated/elapsed_time, 2)} t/s\n"
+        f"\n\t-> Total score: {total_score}/{len(PROMPTS) * n_tries}, took {elapsed_time} seconds\n"
     )
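
Taken together, the benchmark's client calls after this commit follow the pattern sketched below. This is a minimal illustration only, assuming the `client`, `PERSONA`, `HUMAN`, and `send_message` definitions from this file; the agent name, prompt text, and `fn_type` value are hypothetical placeholders, not values from the commit.

    # Minimal sketch of the post-fix flow (assumes `client`, PERSONA, HUMAN,
    # and send_message as defined in this file; names below are placeholders).
    agent = client.create_agent(name="benchmark_demo_agent", persona=PERSONA, human=HUMAN)

    # send_message() now returns a (success, message) pair; token counts are
    # no longer threaded through until the client tracks usage again.
    result, msg = send_message(
        message="Please remember that my favorite color is blue.",
        agent_id=agent.id,
        turn=0,
        fn_type="core_memory_append",
        print_msg=True,
        n_tries=1,
    )
    print(result, msg)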