From 16fa71b88e97a808bed3fc0cd1bf91c0c6ef8b60 Mon Sep 17 00:00:00 2001 From: Robin Goetz <35136007+goetzrobin@users.noreply.github.com> Date: Tue, 5 Mar 2024 23:28:03 +0100 Subject: [PATCH] =?UTF-8?q?feat:=20move=20agent=5Fid=20from=20query=20para?= =?UTF-8?q?m=20to=20path=20variable=20and=20remove=20unus=E2=80=A6=20(#109?= =?UTF-8?q?4)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: cpacker --- memgpt/local_llm/chat_completion_proxy.py | 5 +++-- memgpt/server/rest_api/agents/config.py | 18 ++++++++---------- memgpt/server/rest_api/agents/memory.py | 18 +++++------------- tests/test_migrate.py | 6 +++++- tests/test_summarize.py | 2 +- 5 files changed, 22 insertions(+), 27 deletions(-) diff --git a/memgpt/local_llm/chat_completion_proxy.py b/memgpt/local_llm/chat_completion_proxy.py index 0dd938d1..56bb98fa 100644 --- a/memgpt/local_llm/chat_completion_proxy.py +++ b/memgpt/local_llm/chat_completion_proxy.py @@ -126,13 +126,14 @@ def get_chat_completion( # if hasattr(llm_wrapper, "supports_first_message"): if hasattr(llm_wrapper, "supports_first_message") and llm_wrapper.supports_first_message: prompt = llm_wrapper.chat_completion_to_prompt( - messages, functions if functions else [], first_message=first_message, function_documentation=documentation + messages=messages, functions=functions, first_message=first_message, function_documentation=documentation ) else: - prompt = llm_wrapper.chat_completion_to_prompt(messages, functions if functions else [], function_documentation=documentation) + prompt = llm_wrapper.chat_completion_to_prompt(messages=messages, functions=functions, function_documentation=documentation) printd(prompt) except Exception as e: + print(e) raise LocalLLMError( f"Failed to convert ChatCompletion messages into prompt string with wrapper {str(llm_wrapper)} - error: {str(e)}" ) diff --git a/memgpt/server/rest_api/agents/config.py b/memgpt/server/rest_api/agents/config.py index d299b5f2..dc895a91 100644 --- a/memgpt/server/rest_api/agents/config.py +++ b/memgpt/server/rest_api/agents/config.py @@ -1,11 +1,11 @@ import re import uuid from functools import partial +from typing import List, Optional -from fastapi import APIRouter, Body, Depends, Query, HTTPException, status +from fastapi import APIRouter, Body, Depends, HTTPException, status from fastapi.responses import JSONResponse from pydantic import BaseModel, Field -from typing import List, Optional from memgpt.models.pydantic_models import AgentStateModel, LLMConfigModel, EmbeddingConfigModel from memgpt.server.rest_api.auth_token import get_current_user @@ -20,7 +20,6 @@ class GetAgentRequest(BaseModel): class AgentRenameRequest(BaseModel): - agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.") agent_name: str = Field(..., description="New name for the agent.") @@ -51,9 +50,9 @@ def validate_agent_name(name: str) -> str: def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, password: str): get_current_user_with_server = partial(partial(get_current_user, server), password) - @router.get("/agents", tags=["agents"], response_model=GetAgentResponse) + @router.get("/agents/{agent_id}", tags=["agents"], response_model=GetAgentResponse) def get_agent_config( - agent_id: str = Query(..., description="Unique identifier of the agent whose config is requested."), + agent_id: uuid.UUID, user_id: uuid.UUID = Depends(get_current_user_with_server), ): """ @@ -90,8 +89,9 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, sources=attached_sources, ) - @router.patch("/agents/rename", tags=["agents"], response_model=GetAgentResponse) + @router.patch("/agents/{agent_id}/rename", tags=["agents"], response_model=GetAgentResponse) def update_agent_name( + agent_id: uuid.UUID, request: AgentRenameRequest = Body(...), user_id: uuid.UUID = Depends(get_current_user_with_server), ): @@ -100,8 +100,6 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, This changes the name of the agent in the database but does NOT edit the agent's persona. """ - agent_id = uuid.UUID(request.agent_id) if request.agent_id else None - valid_name = validate_agent_name(request.agent_name) interface.clear() @@ -113,9 +111,9 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, raise HTTPException(status_code=500, detail=f"{e}") return GetAgentResponse(agent_state=agent_state) - @router.delete("/agents", tags=["agents"]) + @router.delete("/agents/{agent_id}", tags=["agents"]) def delete_agent( - agent_id: str = Query(..., description="Unique identifier of the agent to be deleted."), + agent_id: uuid.UUID, user_id: uuid.UUID = Depends(get_current_user_with_server), ): """ diff --git a/memgpt/server/rest_api/agents/memory.py b/memgpt/server/rest_api/agents/memory.py index 599c8c50..a3a972ad 100644 --- a/memgpt/server/rest_api/agents/memory.py +++ b/memgpt/server/rest_api/agents/memory.py @@ -1,7 +1,7 @@ import uuid from functools import partial -from fastapi import APIRouter, Depends, Body, Query +from fastapi import APIRouter, Depends, Body from pydantic import BaseModel, Field from memgpt.server.rest_api.auth_token import get_current_user @@ -16,10 +16,6 @@ class CoreMemory(BaseModel): persona: str | None = Field(None, description="Persona element of the core memory.") -class GetAgentMemoryRequest(BaseModel): - agent_id: str = Field(..., description="The unique identifier of the agent.") - - class GetAgentMemoryResponse(BaseModel): core_memory: CoreMemory = Field(..., description="The state of the agent's core memory.") recall_memory: int = Field(..., description="Size of the agent's recall memory.") @@ -41,9 +37,9 @@ class UpdateAgentMemoryResponse(BaseModel): def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface, password: str): get_current_user_with_server = partial(partial(get_current_user, server), password) - @router.get("/agents/memory", tags=["agents"], response_model=GetAgentMemoryResponse) + @router.get("/agents/{agent_id}/memory", tags=["agents"], response_model=GetAgentMemoryResponse) def get_agent_memory( - agent_id: str = Query(..., description="The unique identifier of the agent."), + agent_id: uuid.UUID, user_id: uuid.UUID = Depends(get_current_user_with_server), ): """ @@ -51,17 +47,13 @@ def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface, This endpoint fetches the current memory state of the agent identified by the user ID and agent ID. """ - # Validate with the Pydantic model (optional) - request = GetAgentMemoryRequest(agent_id=agent_id) - - agent_id = uuid.UUID(request.agent_id) if request.agent_id else None - interface.clear() memory = server.get_agent_memory(user_id=user_id, agent_id=agent_id) return GetAgentMemoryResponse(**memory) - @router.post("/agents/memory", tags=["agents"], response_model=UpdateAgentMemoryResponse) + @router.post("/agents/{agent_id}/memory", tags=["agents"], response_model=UpdateAgentMemoryResponse) def update_agent_memory( + agent_id: uuid.UUID, request: UpdateAgentMemoryRequest = Body(...), user_id: uuid.UUID = Depends(get_current_user_with_server), ): diff --git a/tests/test_migrate.py b/tests/test_migrate.py index fa9dec66..37a5aa8d 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -1,7 +1,7 @@ import os from memgpt.migrate import migrate_all_agents, migrate_all_sources from memgpt.config import MemGPTConfig -from .utils import wipe_config +from .utils import wipe_config, create_config from memgpt.server.server import SyncServer import shutil import uuid @@ -9,6 +9,10 @@ import uuid def test_migrate_0211(): wipe_config() + if os.getenv("OPENAI_API_KEY"): + create_config("openai") + else: + create_config("memgpt_hosted") data_dir = "tests/data/memgpt-0.2.11" tmp_dir = f"tmp_{str(uuid.uuid4())}" diff --git a/tests/test_summarize.py b/tests/test_summarize.py index 82f617bc..76e9fffe 100644 --- a/tests/test_summarize.py +++ b/tests/test_summarize.py @@ -17,12 +17,12 @@ agent_obj = None def create_test_agent(): """Create a test agent that we can call functions on""" wipe_config() - global client if os.getenv("OPENAI_API_KEY"): create_config("openai") else: create_config("memgpt_hosted") + global client client = create_client() agent_state = client.create_agent( name=test_agent_name,