feat: Add data sources to REST API (#1118)

This commit is contained in:
Sarah Wooders
2024-03-10 14:34:35 -07:00
committed by GitHub
parent e93d41b57a
commit dc958bcd9e
13 changed files with 419 additions and 51 deletions

View File

@@ -115,12 +115,11 @@ def load_directory(
document_store=None,
passage_store=passage_storage,
)
print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
except Exception as e:
typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
ms.delete_source(source_id=source.id)
print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
except ValueError as e:
typer.secho(f"Failed to load directory from provided information.\n{e}", fg=typer.colors.RED)
raise

View File

@@ -3,13 +3,14 @@ import requests
import uuid
from typing import Dict, List, Union, Optional, Tuple
from memgpt.data_types import AgentState, User, Preset, LLMConfig, EmbeddingConfig
from memgpt.data_types import AgentState, User, Preset, LLMConfig, EmbeddingConfig, Source
from memgpt.cli.cli import QuickstartChoice
from memgpt.cli.cli import set_config_with_dict, quickstart as quickstart_func, str_to_quickstart_choice
from memgpt.config import MemGPTConfig
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
from memgpt.metadata import MetadataStore
from memgpt.data_sources.connectors import DataConnector
def create_client(base_url: Optional[str] = None, token: Optional[str] = None):
@@ -69,6 +70,30 @@ class AbstractClient(object):
def save(self):
raise NotImplementedError
def list_sources(self):
    """List all loaded data sources (abstract; implemented by subclasses)."""
    raise NotImplementedError
def delete_source(self):
    """Delete a source and associated data (including data attached to agents)."""
    raise NotImplementedError
def load_file_into_source(self, filename: str, source_id: uuid.UUID):
    """Load a file's contents and insert them into the given source."""
    # NOTE(review): original docstring read "Load (unknown) and insert into
    # source" — looks like a garbled placeholder; confirm intended wording.
    raise NotImplementedError
def create_source(self, name: str):
    """Create a new data source with the given name."""
    raise NotImplementedError
def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID):
    """Attach a source to an agent so its passages become searchable by the agent."""
    raise NotImplementedError
def detach_source(self, source_id: uuid.UUID, agent_id: uuid.UUID):
    """Detach a previously-attached source from an agent."""
    raise NotImplementedError
class RESTClient(AbstractClient):
def __init__(
@@ -127,17 +152,17 @@ class RESTClient(AbstractClient):
)
return agent_state
def delete_agent(self, agent_id: str):
response = requests.delete(f"{self.base_url}/api/agents/{agent_id}", headers=self.headers)
return agent_id
def delete_agent(self, agent_id: uuid.UUID):
response = requests.delete(f"{self.base_url}/api/agents/{str(agent_id)}", headers=self.headers)
assert response.status_code == 200, f"Failed to delete agent: {response.text}"
def create_preset(self, preset: Preset):
raise NotImplementedError
def get_agent_config(self, agent_id: str) -> AgentState:
def get_agent_config(self, agent_id: uuid.UUID) -> AgentState:
raise NotImplementedError
def get_agent_memory(self, agent_id: str) -> Dict:
def get_agent_memory(self, agent_id: uuid.UUID) -> Dict:
raise NotImplementedError
def update_agent_core_memory(self, agent_id: str, new_memory_contents: Dict) -> Dict:
@@ -157,6 +182,53 @@ class RESTClient(AbstractClient):
def save(self):
raise NotImplementedError
def list_sources(self):
    """List all data sources available to this user.

    Returns:
        The JSON-decoded response body (a list of source records).
    """
    response = requests.get(f"{self.base_url}/api/sources", headers=self.headers)
    # Fail loudly on a non-200 like the sibling source methods do, instead of
    # silently returning an error payload to the caller.
    assert response.status_code == 200, f"Failed to list sources: {response.text}"
    return response.json()
def delete_source(self, source_id: uuid.UUID):
    """Delete a source and associated data (including data attached to agents)."""
    url = f"{self.base_url}/api/sources/{str(source_id)}"
    response = requests.delete(url, headers=self.headers)
    assert response.status_code == 200, f"Failed to delete source: {response.text}"
def load_file_into_source(self, filename: str, source_id: uuid.UUID):
    """Upload a local file and load its contents into the given source.

    Args:
        filename: Path of the local file to upload.
        source_id: Unique identifier of the target source.

    Returns:
        The JSON-decoded server response.
    """
    params = {"source_id": str(source_id)}
    # Use a context manager so the file handle is always closed — the
    # original `open(filename, "rb")` leaked the handle.
    with open(filename, "rb") as f:
        response = requests.post(
            f"{self.base_url}/api/sources/upload", files={"file": f}, params=params, headers=self.headers
        )
    return response.json()
def create_source(self, name: str) -> Source:
    """Create a new data source and return it as a `Source` object.

    Raises:
        AssertionError: If the server does not return HTTP 200.
    """
    # Local import: a module-level `datetime` import is not visible in this
    # file's header, so guarantee the name is bound here.
    import datetime

    payload = {"name": name}
    response = requests.post(f"{self.base_url}/api/sources", json=payload, headers=self.headers)
    # Check the status before indexing into the payload so failures surface
    # as a clear error rather than a KeyError. (Debug print removed.)
    assert response.status_code == 200, f"Failed to create source: {response.text}"
    response_json = response.json()
    # The server serializes `created_at` as a unix timestamp (see the
    # create_source route, which calls `.timestamp()`).
    return Source(
        id=uuid.UUID(response_json["id"]),
        name=response_json["name"],
        user_id=uuid.UUID(response_json["user_id"]),
        created_at=datetime.datetime.fromtimestamp(response_json["created_at"]),
        embedding_dim=response_json["embedding_config"]["embedding_dim"],
        embedding_model=response_json["embedding_config"]["embedding_model"],
    )
def attach_source_to_agent(self, source_name: str, agent_id: uuid.UUID):
    """Attach a source (by name) to an agent.

    Returns:
        The JSON-decoded server response.
    """
    # Serialize the UUID explicitly for the query string, consistent with
    # detach_source (the original passed the raw UUID object).
    params = {"source_name": source_name, "agent_id": str(agent_id)}
    response = requests.post(f"{self.base_url}/api/sources/attach", params=params, headers=self.headers)
    assert response.status_code == 200, f"Failed to attach source to agent: {response.text}"
    return response.json()
return response.json()
def detach_source(self, source_name: str, agent_id: uuid.UUID):
    """Detach a source (by name) from an agent."""
    query = {"source_name": source_name, "agent_id": str(agent_id)}
    response = requests.post(f"{self.base_url}/api/sources/detach", params=query, headers=self.headers)
    assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
    return response.json()
class LocalClient(AbstractClient):
def __init__(
@@ -267,3 +339,15 @@ class LocalClient(AbstractClient):
def save(self):
self.server.save_agents()
def load_data(self, connector: DataConnector, source_name: str):
    """Load data from a connector into the named source (delegates to the server)."""
    self.server.load_data(connector=connector, source_name=source_name, user_id=self.user_id)
def create_source(self, name: str):
    """Create a new data source and return the server's created record.

    The original dropped the server's return value, unlike
    RESTClient.create_source which returns the created Source.
    """
    return self.server.create_source(user_id=self.user_id, name=name)
def attach_source_to_agent(self, source_name: str, agent_id: uuid.UUID):
    """Attach a source (by name) to an agent (delegates to the server)."""
    self.server.attach_source_to_agent(source_name=source_name, agent_id=agent_id, user_id=self.user_id)
def delete_agent(self, agent_id: uuid.UUID):
    """Delete the given agent (delegates to the server)."""
    self.server.delete_agent(agent_id=agent_id, user_id=self.user_id)

View File

@@ -5,6 +5,7 @@ from memgpt.embeddings import embedding_model
from memgpt.data_types import Document, Passage
from typing import List, Iterator, Dict, Tuple, Optional
import typer
from llama_index.core import Document as LlamaIndexDocument
@@ -53,7 +54,15 @@ def load_data(
# generate passages
for passage_text, passage_metadata in connector.generate_passages([document], chunk_size=embedding_config.embedding_chunk_size):
embedding = embed_model.get_text_embedding(passage_text)
try:
embedding = embed_model.get_text_embedding(passage_text)
except Exception as e:
typer.secho(
f"Warning: Failed to get embedding for {passage_text} (error: {str(e)}), skipping insert into VectorDB.",
fg=typer.colors.YELLOW,
)
continue
passage = Passage(
id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"),
text=passage_text,

View File

@@ -1,8 +1,11 @@
from typing import List, Optional, Dict, Literal
from pydantic import BaseModel, Field, Json, ConfigDict
import uuid
import base64
import numpy as np
from datetime import datetime
from sqlmodel import Field, SQLModel
from sqlalchemy import JSON, Column, BINARY, TypeDecorator
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS, MAX_EMBEDDING_DIM
from memgpt.utils import get_human_text, get_persona_text, printd
@@ -83,3 +86,38 @@ class PersonaModel(SQLModel, table=True):
name: str = Field(..., description="The name of the persona.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the persona.", primary_key=True)
user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the persona.")
class SourceModel(SQLModel, table=True):
    """Database-backed record of a data source."""

    name: str = Field(..., description="The name of the source.")
    # Optional: sources may be created without a description, so the
    # annotation must admit the None default (original declared plain `str`).
    description: Optional[str] = Field(None, description="The description of the source.")
    user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the source.")
    # Stored as a datetime, not a unix timestamp (the original description was wrong).
    created_at: datetime = Field(default_factory=datetime.now, description="The creation datetime of the source.")
    id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the source.", primary_key=True)

    # embedding info
    # embedding_config: EmbeddingConfigModel = Field(..., description="The embedding configuration used by the source.")
    embedding_config: Optional[EmbeddingConfigModel] = Field(
        None, sa_column=Column(JSON), description="The embedding configuration used by the source."
    )
class PassageModel(BaseModel):
    """A text passage (chunk) with optional embedding and provenance info."""

    user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user associated with the passage.")
    agent_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the agent associated with the passage.")
    text: str = Field(..., description="The text of the passage.")
    embedding: Optional[List[float]] = Field(None, description="The embedding of the passage.")
    # NOTE(review): `sa_column` is a SQLModel/SQLAlchemy argument; on a plain
    # pydantic BaseModel it has no table effect — confirm whether this model
    # was meant to be a SQLModel table like SourceModel.
    embedding_config: Optional[EmbeddingConfigModel] = Field(
        None, sa_column=Column(JSON), description="The embedding configuration used by the passage."
    )
    data_source: Optional[str] = Field(None, description="The data source of the passage.")
    doc_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the document associated with the passage.")
    # NOTE(review): `primary_key` is also inert on a BaseModel — confirm intent.
    id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the passage.", primary_key=True)
    metadata: Optional[Dict] = Field({}, description="The metadata of the passage.")
class DocumentModel(BaseModel):
    """A full source document from which passages are chunked."""

    user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the document.")
    text: str = Field(..., description="The text of the document.")
    data_source: str = Field(..., description="The data source of the document.")
    # NOTE(review): `primary_key` has no effect on a plain pydantic BaseModel — confirm intent.
    id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the document.", primary_key=True)
    metadata: Optional[Dict] = Field({}, description="The metadata of the document.")

View File

@@ -100,10 +100,8 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
raise HTTPException(status_code=500, detail=f"{e}")
return CreateUserResponse(user_id=str(new_user_ret.id), api_key=token.token)
@router.delete("/users", tags=["admin"], response_model=DeleteUserResponse)
def delete_user(
user_id: str = Query(..., description="The ID of the user to be deleted."),
):
@router.delete("/users/{user_id}", tags=["admin"], response_model=DeleteUserResponse)
def delete_user(user_id):
# TODO make a soft deletion, instead of a hard deletion
try:
user_id_uuid = uuid.UUID(user_id)

View File

@@ -1,11 +1,11 @@
import re
import uuid
from functools import partial
from typing import List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, status
from fastapi import APIRouter, Body, Depends, Query, HTTPException, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from typing import List, Optional
from memgpt.models.pydantic_models import AgentStateModel, LLMConfigModel, EmbeddingConfigModel
from memgpt.server.rest_api.auth_token import get_current_user
@@ -20,6 +20,7 @@ class GetAgentRequest(BaseModel):
class AgentRenameRequest(BaseModel):
agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.")
agent_name: str = Field(..., description="New name for the agent.")
@@ -50,9 +51,9 @@ def validate_agent_name(name: str) -> str:
def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/agents/{agent_id}", tags=["agents"], response_model=GetAgentResponse)
@router.get("/agents", tags=["agents"], response_model=GetAgentResponse)
def get_agent_config(
agent_id: uuid.UUID,
agent_id: str = Query(..., description="Unique identifier of the agent whose config is requested."),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
@@ -68,8 +69,8 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
interface.clear()
agent_state = server.get_agent_config(user_id=user_id, agent_id=agent_id)
# return GetAgentResponse(agent_state=agent_state)
LLMConfigModel(**vars(agent_state.llm_config))
EmbeddingConfigModel(**vars(agent_state.embedding_config))
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))
return GetAgentResponse(
agent_state=AgentStateModel(
@@ -89,9 +90,8 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
sources=attached_sources,
)
@router.patch("/agents/{agent_id}/rename", tags=["agents"], response_model=GetAgentResponse)
@router.patch("/agents/rename", tags=["agents"], response_model=GetAgentResponse)
def update_agent_name(
agent_id: uuid.UUID,
request: AgentRenameRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
@@ -100,6 +100,8 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
This changes the name of the agent in the database but does NOT edit the agent's persona.
"""
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
valid_name = validate_agent_name(request.agent_name)
interface.clear()
@@ -113,15 +115,13 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
@router.delete("/agents/{agent_id}", tags=["agents"])
def delete_agent(
agent_id: uuid.UUID,
agent_id,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Delete an agent.
"""
request = GetAgentRequest(agent_id=agent_id)
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
agent_id = uuid.UUID(agent_id)
interface.clear()
try:

View File

@@ -21,6 +21,7 @@ from memgpt.server.rest_api.openai_assistants.assistants import setup_openai_ass
from memgpt.server.rest_api.personas.index import setup_personas_index_router
from memgpt.server.rest_api.static_files import mount_static_files
from memgpt.server.rest_api.tools.index import setup_tools_index_router
from memgpt.server.rest_api.sources.index import setup_sources_index_router
from memgpt.server.server import SyncServer
"""
@@ -94,6 +95,7 @@ app.include_router(setup_humans_index_router(server, interface, password), prefi
app.include_router(setup_personas_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_tools_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX)
# /api/config endpoints
app.include_router(setup_config_index_router(server, interface, password), prefix=API_PREFIX)

View File

@@ -0,0 +1,165 @@
import uuid
from functools import partial
from typing import List, Optional
from fastapi import APIRouter, Body, Depends, Query, HTTPException, status, UploadFile
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from memgpt.models.pydantic_models import SourceModel, PassageModel, DocumentModel
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
from memgpt.data_types import Source
from memgpt.data_sources.connectors import DirectoryConnector
router = APIRouter()
"""
Implement the following functions:
* List all available sources
* Create a new source
* Delete a source
* Upload a file to a server that is loaded into a specific source
* Paginated get all passages from a source
* Paginated get all documents from a source
* Attach a source to an agent
"""
class ListSourcesResponse(BaseModel):
    """Response body for GET /sources."""

    sources: List[SourceModel] = Field(..., description="List of available sources")
class CreateSourceRequest(BaseModel):
    """Request body for POST /sources."""

    name: str = Field(..., description="The name of the source.")
    description: Optional[str] = Field(None, description="The description of the source.")
class CreateSourceResponse(BaseModel):
    """Wrapper around a newly created source.

    NOTE(review): the POST /sources route declares response_model=SourceModel
    directly, so this wrapper appears unused in this file — confirm.
    """

    source: SourceModel = Field(..., description="The newly created source.")
class UploadFileToSourceRequest(BaseModel):
    """Request body for uploading a file into a source.

    NOTE(review): the upload route takes `UploadFile` as a direct parameter,
    not this model — confirm whether this model is still needed.
    """

    file: UploadFile = Field(..., description="The file to upload.")
class UploadFileToSourceResponse(BaseModel):
    """Response body for POST /sources/upload."""

    source: SourceModel = Field(..., description="The source the file was uploaded to.")
    added_passages: int = Field(..., description="The number of passages added to the source.")
    added_documents: int = Field(..., description="The number of documents added to the source.")
class GetSourcePassagesResponse(BaseModel):
    """Response body for GET /sources/passages."""

    passages: List[PassageModel] = Field(..., description="List of passages from the source.")
class GetSourceDocumentsResponse(BaseModel):
    """Response body for GET /sources/documents."""

    documents: List[DocumentModel] = Field(..., description="List of documents from the source.")
def setup_sources_index_router(server: SyncServer, interface: QueuingInterface, password: str):
    """Register all /sources endpoints on the shared router and return it.

    Routes cover: listing, creating, and deleting sources; attaching and
    detaching a source to/from an agent; uploading a file into a source; and
    (stubbed) paginated passage/document listings.
    """
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.get("/sources", tags=["sources"], response_model=ListSourcesResponse)
    async def list_source(
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """List all data sources owned by the requesting user."""
        # Clear the interface
        interface.clear()
        sources = server.ms.list_sources(user_id=user_id)
        return ListSourcesResponse(sources=sources)

    @router.post("/sources", tags=["sources"], response_model=SourceModel)
    async def create_source(
        request: CreateSourceRequest = Body(...),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """Create a new data source owned by the requesting user."""
        interface.clear()
        # TODO: don't use Source and just use SourceModel once pydantic migration is complete
        source = server.create_source(name=request.name, user_id=user_id)
        return SourceModel(
            name=source.name,
            description=None,  # TODO: actually store descriptions
            user_id=source.user_id,
            id=source.id,
            embedding_config=server.server_embedding_config,
            created_at=source.created_at.timestamp(),
        )

    @router.delete("/sources/{source_id}", tags=["sources"])
    async def delete_source(
        source_id,
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """Delete a source and all of its associated data."""
        interface.clear()
        try:
            server.delete_source(source_id=uuid.UUID(source_id), user_id=user_id)
            return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Source source_id={source_id} successfully deleted"})
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    @router.post("/sources/attach", tags=["sources"], response_model=SourceModel)
    async def attach_source_to_agent(
        agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to attach the source to."),
        source_name: str = Query(..., description="The name of the source to attach."),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """Attach an existing source (by name) to an agent."""
        interface.clear()
        assert isinstance(agent_id, uuid.UUID), f"Expected agent_id to be a UUID, got {agent_id}"
        assert isinstance(user_id, uuid.UUID), f"Expected user_id to be a UUID, got {user_id}"
        source = server.attach_source_to_agent(source_name=source_name, agent_id=agent_id, user_id=user_id)
        return SourceModel(
            name=source.name,
            description=None,  # TODO: actually store descriptions
            user_id=source.user_id,
            id=source.id,
            embedding_config=server.server_embedding_config,
            created_at=source.created_at,
        )

    # NOTE: the original declared response_model=SourceModel here but returned
    # None, which fails FastAPI response validation at runtime; the declaration
    # is dropped until the server-side implementation returns the source.
    @router.post("/sources/detach", tags=["sources"])
    async def detach_source_from_agent(
        agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to detach the source from."),
        source_name: str = Query(..., description="The name of the source to detach."),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """Detach a source (by name) from an agent."""
        server.detach_source_from_agent(source_name=source_name, agent_id=agent_id, user_id=user_id)

    @router.post("/sources/upload", tags=["sources"], response_model=UploadFileToSourceResponse)
    async def upload_file_to_source(
        file: UploadFile,
        source_id: uuid.UUID = Query(..., description="The unique identifier of the source to attach."),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """Upload a file and load its contents into the given source."""
        import os
        import shutil
        import tempfile

        interface.clear()
        source = server.ms.get_source(source_id=source_id, user_id=user_id)

        # Persist the uploaded bytes to a temporary file: the original passed
        # only `file.filename` to the connector — a path that generally does
        # not exist on the server's filesystem — so the uploaded content was
        # never actually read.
        suffix = os.path.splitext(file.filename)[1] if file.filename else ""
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            shutil.copyfileobj(file.file, tmp)
            tmp_path = tmp.name
        try:
            # create a directory connector that reads the uploaded file from disk
            connector = DirectoryConnector(input_files=[tmp_path])
            # load the data into the source via the connector
            server.load_data(user_id=user_id, source_name=source.name, connector=connector)
        finally:
            os.remove(tmp_path)

        # TODO: actually return added passages/documents
        return UploadFileToSourceResponse(source=source, added_passages=0, added_documents=0)

    # Path fixed: the original was "/sources/passages " (trailing space), which
    # made the route unreachable at the intended URL.
    @router.get("/sources/passages", tags=["sources"], response_model=GetSourcePassagesResponse)
    async def list_passages(
        source_id: uuid.UUID = Query(..., description="The unique identifier of the source."),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """Paginated listing of a source's passages (not yet implemented)."""
        raise NotImplementedError

    @router.get("/sources/documents", tags=["sources"], response_model=GetSourceDocumentsResponse)
    async def list_documents(
        source_id: uuid.UUID = Query(..., description="The unique identifier of the source."),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """Paginated listing of a source's documents (not yet implemented)."""
        raise NotImplementedError

    return router

View File

@@ -1075,13 +1075,25 @@ class SyncServer(LockingServer):
embedding_dim=self.config.default_embedding_config.embedding_dim,
)
self.ms.create_source(source)
assert self.ms.get_source(source_name=name, user_id=user_id) is not None, f"Failed to create source {name}"
return source
def delete_source(self, source_id: uuid.UUID, user_id: uuid.UUID):
    """Delete a data source and its passages from the passage store.

    Raises:
        ValueError: If no source with `source_id` exists for this user.
    """
    source = self.ms.get_source(source_id=source_id, user_id=user_id)
    if source is None:
        # Fail with a clear error instead of AttributeError on `source.name`
        # below (mirrors the check in attach_source_to_agent).
        raise ValueError(f"Data source {source_id} does not exist for user_id {user_id}")
    self.ms.delete_source(source_id)

    # delete data from passage store
    passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
    passage_store.delete({"data_source": source.name})

    # TODO: delete data from agent passage stores (?)
def load_data(
self,
user_id: uuid.UUID,
connector: DataConnector,
source_name: Source,
source_name: str,
):
"""Load data from a DataConnector into a source for a specified user_id"""
# TODO: this should be implemented as a batch job or at least async, since it may take a long time
@@ -1103,7 +1115,7 @@ class SyncServer(LockingServer):
# attach a data source to an agent
data_source = self.ms.get_source(source_name=source_name, user_id=user_id)
if data_source is None:
raise ValueError(f"Data source {source_name} does not exist")
raise ValueError(f"Data source {source_name} does not exist for user_id {user_id}")
# get connection to data source storage
source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
@@ -1114,6 +1126,12 @@ class SyncServer(LockingServer):
# attach source to agent
agent.attach_source(data_source.name, source_connector, self.ms)
return data_source
def detach_source_from_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str):
    """Detach a data source from an agent (not yet implemented)."""
    # TODO: remove all passages corresponding to the source from the agent's archival memory
    raise NotImplementedError
def list_attached_sources(self, agent_id: uuid.UUID):
    """Return all data sources currently attached to the given agent."""
    return self.ms.list_attached_sources(agent_id)

16
poetry.lock generated
View File

@@ -4103,6 +4103,20 @@ files = [
[package.extras]
cli = ["click (>=5.0)"]
[[package]]
name = "python-multipart"
version = "0.0.9"
description = "A streaming multipart parser for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "python_multipart-0.0.9-py3-none-any.whl", hash = "sha256:97ca7b8ea7b05f977dc3849c3ba99d51689822fab725c3703af7c866a0c2b215"},
{file = "python_multipart-0.0.9.tar.gz", hash = "sha256:03f54688c663f1b7977105f021043b0793151e4cb1c1a9d4a11fc13d622c4026"},
]
[package.extras]
dev = ["atomicwrites (==1.4.1)", "attrs (==23.2.0)", "coverage (==7.4.1)", "hatch", "invoke (==2.2.0)", "more-itertools (==10.2.0)", "pbr (==6.0.0)", "pluggy (==1.4.0)", "py (==1.11.0)", "pytest (==8.0.0)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.2.0)", "pyyaml (==6.0.1)", "ruff (==0.2.1)"]
[[package]]
name = "pytz"
version = "2023.4"
@@ -5898,4 +5912,4 @@ server = ["fastapi", "uvicorn", "websockets"]
[metadata]
lock-version = "2.0"
python-versions = "<3.12,>=3.10"
content-hash = "1b35809af89064c19823842ed40f5c8b9ceb8e68315fb482e2ae3b9f8cac0fad"
content-hash = "509bcb6fde67eb0c2d0a3997d6401e752d4c0cf7397a363f35167e6dc714ab0b"

View File

@@ -57,6 +57,7 @@ llama-index = "^0.10.6"
llama-index-embeddings-openai = "^0.1.1"
llama-index-embeddings-huggingface = {version = "^0.1.4", optional = true}
llama-index-embeddings-azure-openai = "^0.1.6"
python-multipart = "^0.0.9"
[tool.poetry.extras]
local = ["llama-index-embeddings-huggingface"]

View File

@@ -1,16 +1,15 @@
import uuid
import os
import time
import threading
from dotenv import load_dotenv
from memgpt import Admin, create_client
from memgpt.constants import DEFAULT_PRESET
import pytest
import uuid
test_agent_name = f"test_client_{str(uuid.uuid4())}"
# test_preset_name = "test_preset"
test_preset_name = DEFAULT_PRESET
@@ -60,6 +59,7 @@ def user_token():
# Fixture to create clients with different configurations
@pytest.fixture(params=[{"base_url": test_base_url}, {"base_url": None}], scope="module")
# @pytest.fixture(params=[{"base_url": test_base_url}], scope="module")
def client(request, user_token):
# use token or not
if request.param["base_url"]:
@@ -71,6 +71,17 @@ def client(request, user_token):
yield client
# Fixture for test agent
@pytest.fixture(scope="module")
def agent(client):
    """Module-scoped agent for the source tests; deleted on fixture teardown."""
    agent_state = client.create_agent(name=test_agent_name, preset=test_preset_name)
    print("AGENT ID", agent_state.id)
    yield agent_state

    # delete agent
    client.delete_agent(agent_state.id)
# TODO: add back once REST API supports
# def test_create_preset(client):
#
@@ -86,26 +97,55 @@ def client(request, user_token):
# client.create_preset(preset)
def test_create_agent(client):
global test_agent_state
test_agent_state = client.create_agent(
name=test_agent_name,
preset=test_preset_name,
)
print(f"\n\n[1] CREATED AGENT {test_agent_state.id}!!!\n\tmessages={test_agent_state.state['messages']}")
assert test_agent_state is not None
# def test_create_agent(client):
# global test_agent_state
# test_agent_state = client.create_agent(
# name=test_agent_name,
# preset=test_preset_name,
# )
# print(f"\n\n[1] CREATED AGENT {test_agent_state.id}!!!\n\tmessages={test_agent_state.state['messages']}")
# assert test_agent_state is not None
def test_user_message(client):
    """Test that we can send a message through the client"""
    assert client is not None, "Run create_agent test first"
    # NOTE(review): relies on the module-global `test_agent_state` set by
    # test_create_agent — confirm test ordering before refactoring.
    print(f"\n\n[2] SENDING MESSAGE TO AGENT {test_agent_state.id}!!!\n\tmessages={test_agent_state.state['messages']}")
    response = client.user_message(agent_id=test_agent_state.id, message="Hello my name is Test, Client Test")
    assert response is not None and len(response) > 0
def test_sources(client, agent):
    """Exercise the data-source endpoints end to end.

    Covers: listing sources, creating a source, uploading a file into it,
    attaching it to an agent, and finally deleting it.
    """
    if not hasattr(client, "base_url"):
        # LocalClient exposes no REST endpoint, so there is nothing to test.
        # (Original skip message incorrectly said "base_url is None".)
        pytest.skip("Skipping test_sources because client has no base_url (not a REST client)")

    # list sources (should start without the test source)
    sources = client.list_sources()
    print("listed sources", sources)

    # create a source
    source = client.create_source(name="test_source")

    # list sources again; the new source must be present
    sources = client.list_sources()
    print("listed sources", sources)
    assert len(sources) == 1

    # load a file into a source
    filename = "CONTRIBUTING.md"
    response = client.load_file_into_source(filename, source.id)
    print(response)

    # attach a source
    # TODO: make sure things run in the right order
    client.attach_source_to_agent(source_name="test_source", agent_id=agent.id)

    # TODO: list archival memory

    # detach the source
    # TODO: add when implemented
    # client.detach_source(source.name, agent.id)

    # delete the source
    client.delete_source(source.id)
# def test_user_message(client, agent):
# """Test that we can send a message through the client"""
# assert client is not None, "Run create_agent test first"
# print(f"\n\n[2] SENDING MESSAGE TO AGENT {agent.id}!!!\n\tmessages={agent.state['messages']}")
# response = client.user_message(agent_id=agent.id, message="Hello my name is Test, Client Test")
# assert response is not None and len(response) > 0