diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py index f8bf5b5e..3dee356f 100644 --- a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -28,88 +28,6 @@ from letta_client.types import ( from letta.schemas.agent import AgentState from letta.schemas.llm_config import LLMConfig -# ------------------------------ -# Fixtures -# ------------------------------ - - -@pytest.fixture(scope="module") -def server_url() -> str: - """ - Provides the URL for the Letta server. - If LETTA_SERVER_URL is not set, starts the server in a background thread - and polls until it’s accepting connections. - """ - - def _run_server() -> None: - load_dotenv() - from letta.server.rest_api.app import start_server - - start_server(debug=True) - - url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") - - if not os.getenv("LETTA_SERVER_URL"): - thread = threading.Thread(target=_run_server, daemon=True) - thread.start() - - # Poll until the server is up (or timeout) - timeout_seconds = 30 - deadline = time.time() + timeout_seconds - while time.time() < deadline: - try: - resp = requests.get(url + "/v1/health") - if resp.status_code < 500: - break - except requests.exceptions.RequestException: - pass - time.sleep(0.1) - else: - raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s") - - return url - - -@pytest.fixture(scope="module") -def client(server_url: str) -> Letta: - """ - Creates and returns a synchronous Letta REST client for testing. - """ - client_instance = Letta(base_url=server_url) - yield client_instance - - -@pytest.fixture(scope="function") -def async_client(server_url: str) -> AsyncLetta: - """ - Creates and returns an asynchronous Letta REST client for testing. - """ - async_client_instance = AsyncLetta(base_url=server_url) - yield async_client_instance - - -@pytest.fixture(scope="module") -def agent_state(client: Letta) -> AgentState: - """ - Creates and returns an agent state for testing with a pre-configured agent. - The agent is named 'supervisor' and is configured with base tools and the roll_dice tool. - """ - client.tools.upsert_base_tools() - - send_message_tool = client.tools.list(name="send_message")[0] - agent_state_instance = client.agents.create( - name="supervisor", - include_base_tools=False, - tool_ids=[send_message_tool.id], - model="openai/gpt-4o", - embedding="letta/letta-free", - tags=["supervisor"], - ) - yield agent_state_instance - - client.agents.delete(agent_state_instance.id) - - # ------------------------------ # Helper Functions and Constants # ------------------------------ @@ -175,7 +93,7 @@ USER_MESSAGE_BASE64_IMAGE: List[MessageCreate] = [ ] all_configs = [ "openai-gpt-4o-mini.json", - # "azure-gpt-4o-mini.json", # TODO: Re-enable on new agent loop + "azure-gpt-4o-mini.json", "claude-3-5-sonnet.json", "claude-3-7-sonnet.json", "claude-3-7-sonnet-extended.json", @@ -377,19 +295,6 @@ def accumulate_chunks(chunks: List[Any]) -> List[Any]: return [m for m in messages if m is not None] -def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run: - start = time.time() - while True: - run = client.runs.retrieve(run_id) - if run.status == "completed": - return run - if run.status == "failed": - raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}") - if time.time() - start > timeout: - raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds (last status: {run.status})") - time.sleep(interval) - - def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None: """ Asserts that a list of message dictionaries contains the expected types and statuses. @@ -406,6 +311,108 @@ def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None: assert messages[1]["message_type"] == "assistant_message" +# ------------------------------ +# Fixtures +# ------------------------------ + + +@pytest.fixture(scope="module") +def server_url() -> str: + """ + Provides the URL for the Letta server. + If LETTA_SERVER_URL is not set, starts the server in a background thread + and polls until it’s accepting connections. + """ + + def _run_server() -> None: + load_dotenv() + from letta.server.rest_api.app import start_server + + start_server(debug=True) + + url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") + + if not os.getenv("LETTA_SERVER_URL"): + thread = threading.Thread(target=_run_server, daemon=True) + thread.start() + + # Poll until the server is up (or timeout) + timeout_seconds = 30 + deadline = time.time() + timeout_seconds + while time.time() < deadline: + try: + resp = requests.get(url + "/v1/health") + if resp.status_code < 500: + break + except requests.exceptions.RequestException: + pass + time.sleep(0.1) + else: + raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s") + + return url + + +@pytest.fixture(scope="module") +def client(server_url: str) -> Letta: + """ + Creates and returns a synchronous Letta REST client for testing. + """ + client_instance = Letta(base_url=server_url) + yield client_instance + + +@pytest.fixture(scope="function") +def async_client(server_url: str) -> AsyncLetta: + """ + Creates and returns an asynchronous Letta REST client for testing. + """ + async_client_instance = AsyncLetta(base_url=server_url) + yield async_client_instance + + +@pytest.fixture(scope="module") +def agent_state(client: Letta) -> AgentState: + """ + Creates and returns an agent state for testing with a pre-configured agent. + The agent is named 'supervisor' and is configured with base tools and the roll_dice tool. + """ + client.tools.upsert_base_tools() + dice_tool = client.tools.upsert_from_function(func=roll_dice) + + send_message_tool = client.tools.list(name="send_message")[0] + agent_state_instance = client.agents.create( + name="supervisor", + include_base_tools=False, + tool_ids=[send_message_tool.id, dice_tool.id], + model="openai/gpt-4o", + embedding="letta/letta-free", + tags=["supervisor"], + ) + yield agent_state_instance + + client.agents.delete(agent_state_instance.id) + + +@pytest.fixture(scope="module") +def agent_state_no_tools(client: Letta) -> AgentState: + """ + Creates and returns an agent state for testing with a pre-configured agent. + The agent is named 'supervisor' and is configured with no tools. + """ + send_message_tool = client.tools.list(name="send_message")[0] + agent_state_instance = client.agents.create( + name="supervisor", + include_base_tools=False, + model="openai/gpt-4o", + embedding="letta/letta-free", + tags=["supervisor"], + ) + yield agent_state_instance + + client.agents.delete(agent_state_instance.id) + + # ------------------------------ # Test Cases # ------------------------------ @@ -479,8 +486,6 @@ def test_tool_call( Tests sending a message with a synchronous client. Verifies that the response messages follow the expected order. """ - dice_tool = client.tools.upsert_from_function(func=roll_dice) - client.agents.tools.attach(agent_id=agent_state.id, tool_id=dice_tool.id) last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1) agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create( @@ -552,24 +557,22 @@ def test_base64_image_input( def test_agent_loop_error( disable_e2b_api_key: Any, client: Letta, - agent_state: AgentState, + agent_state_no_tools: AgentState, llm_config: LLMConfig, ) -> None: """ Tests sending a message with a synchronous client. Verifies that no new messages are persisted on error. """ - last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1) - tools = agent_state.tools - agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config, tool_ids=[]) + last_message = client.agents.messages.list(agent_id=agent_state_no_tools.id, limit=1) + agent_state_no_tools = client.agents.modify(agent_id=agent_state_no_tools.id, llm_config=llm_config) with pytest.raises(ApiError): client.agents.messages.create( - agent_id=agent_state.id, + agent_id=agent_state_no_tools.id, messages=USER_MESSAGE_FORCE_REPLY, ) - messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) + messages_from_db = client.agents.messages.list(agent_id=agent_state_no_tools.id, after=last_message[0].id) assert len(messages_from_db) == 0 - client.agents.modify(agent_id=agent_state.id, tool_ids=[t.id for t in tools]) @pytest.mark.parametrize( @@ -593,8 +596,7 @@ def test_step_streaming_greeting_with_assistant_message( agent_id=agent_state.id, messages=USER_MESSAGE_FORCE_REPLY, ) - chunks = list(response) - messages = accumulate_chunks(chunks) + messages = accumulate_chunks(list(response)) assert_greeting_with_assistant_message_response(messages, streaming=True) messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) assert_greeting_with_assistant_message_response(messages_from_db, from_db=True) @@ -622,8 +624,7 @@ def test_step_streaming_greeting_without_assistant_message( messages=USER_MESSAGE_FORCE_REPLY, use_assistant_message=False, ) - chunks = list(response) - messages = accumulate_chunks(chunks) + messages = accumulate_chunks(list(response)) assert_greeting_without_assistant_message_response(messages, streaming=True) messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False) assert_greeting_without_assistant_message_response(messages_from_db, from_db=True) @@ -644,16 +645,13 @@ def test_step_streaming_tool_call( Tests sending a streaming message with a synchronous client. Checks that each chunk in the stream has the correct message types. """ - dice_tool = client.tools.upsert_from_function(func=roll_dice) - agent_state = client.agents.tools.attach(agent_id=agent_state.id, tool_id=dice_tool.id) last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1) agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create_stream( agent_id=agent_state.id, messages=USER_MESSAGE_ROLL_DICE, ) - chunks = list(response) - messages = accumulate_chunks(chunks) + messages = accumulate_chunks(list(response)) assert_tool_call_response(messages, streaming=True) messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) assert_tool_call_response(messages_from_db, from_db=True) @@ -667,26 +665,24 @@ def test_step_streaming_tool_call( def test_step_stream_agent_loop_error( disable_e2b_api_key: Any, client: Letta, - agent_state: AgentState, + agent_state_no_tools: AgentState, llm_config: LLMConfig, ) -> None: """ Tests sending a message with a synchronous client. Verifies that no new messages are persisted on error. """ - last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1) - tools = agent_state.tools - agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config, tool_ids=[]) + last_message = client.agents.messages.list(agent_id=agent_state_no_tools.id, limit=1) + agent_state_no_tools = client.agents.modify(agent_id=agent_state_no_tools.id, llm_config=llm_config) with pytest.raises(ApiError): response = client.agents.messages.create_stream( - agent_id=agent_state.id, + agent_id=agent_state_no_tools.id, messages=USER_MESSAGE_FORCE_REPLY, ) list(response) - messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) + messages_from_db = client.agents.messages.list(agent_id=agent_state_no_tools.id, after=last_message[0].id) assert len(messages_from_db) == 0 - client.agents.modify(agent_id=agent_state.id, tool_ids=[t.id for t in tools]) @pytest.mark.parametrize( @@ -711,8 +707,7 @@ def test_token_streaming_greeting_with_assistant_message( messages=USER_MESSAGE_FORCE_REPLY, stream_tokens=True, ) - chunks = list(response) - messages = accumulate_chunks(chunks) + messages = accumulate_chunks(list(response)) assert_greeting_with_assistant_message_response(messages, streaming=True, token_streaming=True) messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) assert_greeting_with_assistant_message_response(messages_from_db, from_db=True) @@ -741,8 +736,7 @@ def test_token_streaming_greeting_without_assistant_message( use_assistant_message=False, stream_tokens=True, ) - chunks = list(response) - messages = accumulate_chunks(chunks) + messages = accumulate_chunks(list(response)) assert_greeting_without_assistant_message_response(messages, streaming=True, token_streaming=True) messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False) assert_greeting_without_assistant_message_response(messages_from_db, from_db=True) @@ -763,8 +757,6 @@ def test_token_streaming_tool_call( Tests sending a streaming message with a synchronous client. Checks that each chunk in the stream has the correct message types. """ - dice_tool = client.tools.upsert_from_function(func=roll_dice) - agent_state = client.agents.tools.attach(agent_id=agent_state.id, tool_id=dice_tool.id) last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1) agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create_stream( @@ -772,8 +764,7 @@ def test_token_streaming_tool_call( messages=USER_MESSAGE_ROLL_DICE, stream_tokens=True, ) - chunks = list(response) - messages = accumulate_chunks(chunks) + messages = accumulate_chunks(list(response)) assert_tool_call_response(messages, streaming=True) messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) assert_tool_call_response(messages_from_db, from_db=True) @@ -787,19 +778,18 @@ def test_token_streaming_tool_call( def test_token_streaming_agent_loop_error( disable_e2b_api_key: Any, client: Letta, - agent_state: AgentState, + agent_state_no_tools: AgentState, llm_config: LLMConfig, ) -> None: """ Tests sending a message with a synchronous client. Verifies that no new messages are persisted on error. """ - last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1) - tools = agent_state.tools - agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config, tool_ids=[]) + last_message = client.agents.messages.list(agent_id=agent_state_no_tools.id, limit=1) + agent_state_no_tools = client.agents.modify(agent_id=agent_state_no_tools.id, llm_config=llm_config, tool_ids=[]) try: response = client.agents.messages.create_stream( - agent_id=agent_state.id, + agent_id=agent_state_no_tools.id, messages=USER_MESSAGE_FORCE_REPLY, stream_tokens=True, ) @@ -807,9 +797,21 @@ def test_token_streaming_agent_loop_error( except: pass # only some models throw an error TODO: make this consistent - messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) + messages_from_db = client.agents.messages.list(agent_id=agent_state_no_tools.id, after=last_message[0].id) assert len(messages_from_db) == 0 - client.agents.modify(agent_id=agent_state.id, tool_ids=[t.id for t in tools]) + + +def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run: + start = time.time() + while True: + run = client.runs.retrieve(run_id) + if run.status == "completed": + return run + if run.status == "failed": + raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}") + if time.time() - start > timeout: + raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds (last status: {run.status})") + time.sleep(interval) @pytest.mark.parametrize( @@ -850,7 +852,6 @@ def test_async_greeting_with_assistant_message( def test_auto_summarize(disable_e2b_api_key: Any, client: Letta, llm_config: LLMConfig): """Test that summarization is automatically triggered.""" llm_config.context_window = 3000 - client.tools.upsert_base_tools() send_message_tool = client.tools.list(name="send_message")[0] temp_agent_state = client.agents.create(