feat: move agent_id from query param to path variable and remove unus… (#1094)

Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in:
Robin Goetz
2024-03-05 23:28:03 +01:00
committed by GitHub
parent 1ada95ce5d
commit 16fa71b88e
5 changed files with 22 additions and 27 deletions

View File

@@ -126,13 +126,14 @@ def get_chat_completion(
# if hasattr(llm_wrapper, "supports_first_message"):
if hasattr(llm_wrapper, "supports_first_message") and llm_wrapper.supports_first_message:
prompt = llm_wrapper.chat_completion_to_prompt(
messages, functions if functions else [], first_message=first_message, function_documentation=documentation
messages=messages, functions=functions, first_message=first_message, function_documentation=documentation
)
else:
prompt = llm_wrapper.chat_completion_to_prompt(messages, functions if functions else [], function_documentation=documentation)
prompt = llm_wrapper.chat_completion_to_prompt(messages=messages, functions=functions, function_documentation=documentation)
printd(prompt)
except Exception as e:
print(e)
raise LocalLLMError(
f"Failed to convert ChatCompletion messages into prompt string with wrapper {str(llm_wrapper)} - error: {str(e)}"
)

View File

@@ -1,11 +1,11 @@
import re
import uuid
from functools import partial
from typing import List, Optional
from fastapi import APIRouter, Body, Depends, Query, HTTPException, status
from fastapi import APIRouter, Body, Depends, HTTPException, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from typing import List, Optional
from memgpt.models.pydantic_models import AgentStateModel, LLMConfigModel, EmbeddingConfigModel
from memgpt.server.rest_api.auth_token import get_current_user
@@ -20,7 +20,6 @@ class GetAgentRequest(BaseModel):
class AgentRenameRequest(BaseModel):
agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.")
agent_name: str = Field(..., description="New name for the agent.")
@@ -51,9 +50,9 @@ def validate_agent_name(name: str) -> str:
def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/agents", tags=["agents"], response_model=GetAgentResponse)
@router.get("/agents/{agent_id}", tags=["agents"], response_model=GetAgentResponse)
def get_agent_config(
agent_id: str = Query(..., description="Unique identifier of the agent whose config is requested."),
agent_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
@@ -90,8 +89,9 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
sources=attached_sources,
)
@router.patch("/agents/rename", tags=["agents"], response_model=GetAgentResponse)
@router.patch("/agents/{agent_id}/rename", tags=["agents"], response_model=GetAgentResponse)
def update_agent_name(
agent_id: uuid.UUID,
request: AgentRenameRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
@@ -100,8 +100,6 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
This changes the name of the agent in the database but does NOT edit the agent's persona.
"""
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
valid_name = validate_agent_name(request.agent_name)
interface.clear()
@@ -113,9 +111,9 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
raise HTTPException(status_code=500, detail=f"{e}")
return GetAgentResponse(agent_state=agent_state)
@router.delete("/agents", tags=["agents"])
@router.delete("/agents/{agent_id}", tags=["agents"])
def delete_agent(
agent_id: str = Query(..., description="Unique identifier of the agent to be deleted."),
agent_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""

View File

@@ -1,7 +1,7 @@
import uuid
from functools import partial
from fastapi import APIRouter, Depends, Body, Query
from fastapi import APIRouter, Depends, Body
from pydantic import BaseModel, Field
from memgpt.server.rest_api.auth_token import get_current_user
@@ -16,10 +16,6 @@ class CoreMemory(BaseModel):
persona: str | None = Field(None, description="Persona element of the core memory.")
class GetAgentMemoryRequest(BaseModel):
agent_id: str = Field(..., description="The unique identifier of the agent.")
class GetAgentMemoryResponse(BaseModel):
core_memory: CoreMemory = Field(..., description="The state of the agent's core memory.")
recall_memory: int = Field(..., description="Size of the agent's recall memory.")
@@ -41,9 +37,9 @@ class UpdateAgentMemoryResponse(BaseModel):
def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/agents/memory", tags=["agents"], response_model=GetAgentMemoryResponse)
@router.get("/agents/{agent_id}/memory", tags=["agents"], response_model=GetAgentMemoryResponse)
def get_agent_memory(
agent_id: str = Query(..., description="The unique identifier of the agent."),
agent_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
@@ -51,17 +47,13 @@ def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface,
This endpoint fetches the current memory state of the agent identified by the user ID and agent ID.
"""
# Validate with the Pydantic model (optional)
request = GetAgentMemoryRequest(agent_id=agent_id)
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
interface.clear()
memory = server.get_agent_memory(user_id=user_id, agent_id=agent_id)
return GetAgentMemoryResponse(**memory)
@router.post("/agents/memory", tags=["agents"], response_model=UpdateAgentMemoryResponse)
@router.post("/agents/{agent_id}/memory", tags=["agents"], response_model=UpdateAgentMemoryResponse)
def update_agent_memory(
agent_id: uuid.UUID,
request: UpdateAgentMemoryRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):

View File

@@ -1,7 +1,7 @@
import os
from memgpt.migrate import migrate_all_agents, migrate_all_sources
from memgpt.config import MemGPTConfig
from .utils import wipe_config
from .utils import wipe_config, create_config
from memgpt.server.server import SyncServer
import shutil
import uuid
@@ -9,6 +9,10 @@ import uuid
def test_migrate_0211():
wipe_config()
if os.getenv("OPENAI_API_KEY"):
create_config("openai")
else:
create_config("memgpt_hosted")
data_dir = "tests/data/memgpt-0.2.11"
tmp_dir = f"tmp_{str(uuid.uuid4())}"

View File

@@ -17,12 +17,12 @@ agent_obj = None
def create_test_agent():
"""Create a test agent that we can call functions on"""
wipe_config()
global client
if os.getenv("OPENAI_API_KEY"):
create_config("openai")
else:
create_config("memgpt_hosted")
global client
client = create_client()
agent_state = client.create_agent(
name=test_agent_name,