feat: Add data sources to REST API (#1118)
This commit is contained in:
@@ -115,12 +115,11 @@ def load_directory(
|
||||
document_store=None,
|
||||
passage_store=passage_storage,
|
||||
)
|
||||
print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
|
||||
except Exception as e:
|
||||
typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
|
||||
ms.delete_source(source_id=source.id)
|
||||
|
||||
print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
|
||||
|
||||
except ValueError as e:
|
||||
typer.secho(f"Failed to load directory from provided information.\n{e}", fg=typer.colors.RED)
|
||||
raise
|
||||
|
||||
@@ -3,13 +3,14 @@ import requests
|
||||
import uuid
|
||||
from typing import Dict, List, Union, Optional, Tuple
|
||||
|
||||
from memgpt.data_types import AgentState, User, Preset, LLMConfig, EmbeddingConfig
|
||||
from memgpt.data_types import AgentState, User, Preset, LLMConfig, EmbeddingConfig, Source
|
||||
from memgpt.cli.cli import QuickstartChoice
|
||||
from memgpt.cli.cli import set_config_with_dict, quickstart as quickstart_func, str_to_quickstart_choice
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.data_sources.connectors import DataConnector
|
||||
|
||||
|
||||
def create_client(base_url: Optional[str] = None, token: Optional[str] = None):
|
||||
@@ -69,6 +70,30 @@ class AbstractClient(object):
|
||||
def save(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def list_sources(self):
|
||||
"""List loaded sources"""
|
||||
raise NotImplementedError
|
||||
|
||||
def delete_source(self):
|
||||
"""Delete a source and associated data (including attached to agents)"""
|
||||
raise NotImplementedError
|
||||
|
||||
def load_file_into_source(self, filename: str, source_id: uuid.UUID):
|
||||
"""Load {filename} and insert into source"""
|
||||
raise NotImplementedError
|
||||
|
||||
def create_source(self, name: str):
|
||||
"""Create a new source"""
|
||||
raise NotImplementedError
|
||||
|
||||
def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID):
|
||||
"""Attach a source to an agent"""
|
||||
raise NotImplementedError
|
||||
|
||||
def detach_source(self, source_id: uuid.UUID, agent_id: uuid.UUID):
|
||||
"""Detach a source from an agent"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RESTClient(AbstractClient):
|
||||
def __init__(
|
||||
@@ -127,17 +152,17 @@ class RESTClient(AbstractClient):
|
||||
)
|
||||
return agent_state
|
||||
|
||||
def delete_agent(self, agent_id: str):
|
||||
response = requests.delete(f"{self.base_url}/api/agents/{agent_id}", headers=self.headers)
|
||||
return agent_id
|
||||
def delete_agent(self, agent_id: uuid.UUID):
|
||||
response = requests.delete(f"{self.base_url}/api/agents/{str(agent_id)}", headers=self.headers)
|
||||
assert response.status_code == 200, f"Failed to delete agent: {response.text}"
|
||||
|
||||
def create_preset(self, preset: Preset):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_agent_config(self, agent_id: str) -> AgentState:
|
||||
def get_agent_config(self, agent_id: uuid.UUID) -> AgentState:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_agent_memory(self, agent_id: str) -> Dict:
|
||||
def get_agent_memory(self, agent_id: uuid.UUID) -> Dict:
|
||||
raise NotImplementedError
|
||||
|
||||
def update_agent_core_memory(self, agent_id: str, new_memory_contents: Dict) -> Dict:
|
||||
@@ -157,6 +182,53 @@ class RESTClient(AbstractClient):
|
||||
def save(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def list_sources(self):
|
||||
"""List loaded sources"""
|
||||
response = requests.get(f"{self.base_url}/api/sources", headers=self.headers)
|
||||
response_json = response.json()
|
||||
return response_json
|
||||
|
||||
def delete_source(self, source_id: uuid.UUID):
|
||||
"""Delete a source and associated data (including attached to agents)"""
|
||||
response = requests.delete(f"{self.base_url}/api/sources/{str(source_id)}", headers=self.headers)
|
||||
assert response.status_code == 200, f"Failed to delete source: {response.text}"
|
||||
|
||||
def load_file_into_source(self, filename: str, source_id: uuid.UUID):
|
||||
"""Load {filename} and insert into source"""
|
||||
params = {"source_id": str(source_id)}
|
||||
files = {"file": open(filename, "rb")}
|
||||
response = requests.post(f"{self.base_url}/api/sources/upload", files=files, params=params, headers=self.headers)
|
||||
return response.json()
|
||||
|
||||
def create_source(self, name: str) -> Source:
|
||||
"""Create a new source"""
|
||||
payload = {"name": name}
|
||||
response = requests.post(f"{self.base_url}/api/sources", json=payload, headers=self.headers)
|
||||
response_json = response.json()
|
||||
print("CREATE SOURCE", response_json, response.text)
|
||||
return Source(
|
||||
id=uuid.UUID(response_json["id"]),
|
||||
name=response_json["name"],
|
||||
user_id=uuid.UUID(response_json["user_id"]),
|
||||
created_at=datetime.datetime.fromtimestamp(response_json["created_at"]),
|
||||
embedding_dim=response_json["embedding_config"]["embedding_dim"],
|
||||
embedding_model=response_json["embedding_config"]["embedding_model"],
|
||||
)
|
||||
|
||||
def attach_source_to_agent(self, source_name: str, agent_id: uuid.UUID):
|
||||
"""Attach a source to an agent"""
|
||||
params = {"source_name": source_name, "agent_id": agent_id}
|
||||
response = requests.post(f"{self.base_url}/api/sources/attach", params=params, headers=self.headers)
|
||||
assert response.status_code == 200, f"Failed to attach source to agent: {response.text}"
|
||||
return response.json()
|
||||
|
||||
def detach_source(self, source_name: str, agent_id: uuid.UUID):
|
||||
"""Detach a source from an agent"""
|
||||
params = {"source_name": source_name, "agent_id": str(agent_id)}
|
||||
response = requests.post(f"{self.base_url}/api/sources/detach", params=params, headers=self.headers)
|
||||
assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
|
||||
return response.json()
|
||||
|
||||
|
||||
class LocalClient(AbstractClient):
|
||||
def __init__(
|
||||
@@ -267,3 +339,15 @@ class LocalClient(AbstractClient):
|
||||
|
||||
def save(self):
|
||||
self.server.save_agents()
|
||||
|
||||
def load_data(self, connector: DataConnector, source_name: str):
|
||||
self.server.load_data(user_id=self.user_id, connector=connector, source_name=source_name)
|
||||
|
||||
def create_source(self, name: str):
|
||||
self.server.create_source(user_id=self.user_id, name=name)
|
||||
|
||||
def attach_source_to_agent(self, source_name: str, agent_id: uuid.UUID):
|
||||
self.server.attach_source_to_agent(user_id=self.user_id, source_name=source_name, agent_id=agent_id)
|
||||
|
||||
def delete_agent(self, agent_id: uuid.UUID):
|
||||
self.server.delete_agent(user_id=self.user_id, agent_id=agent_id)
|
||||
|
||||
@@ -5,6 +5,7 @@ from memgpt.embeddings import embedding_model
|
||||
from memgpt.data_types import Document, Passage
|
||||
|
||||
from typing import List, Iterator, Dict, Tuple, Optional
|
||||
import typer
|
||||
from llama_index.core import Document as LlamaIndexDocument
|
||||
|
||||
|
||||
@@ -53,7 +54,15 @@ def load_data(
|
||||
|
||||
# generate passages
|
||||
for passage_text, passage_metadata in connector.generate_passages([document], chunk_size=embedding_config.embedding_chunk_size):
|
||||
embedding = embed_model.get_text_embedding(passage_text)
|
||||
try:
|
||||
embedding = embed_model.get_text_embedding(passage_text)
|
||||
except Exception as e:
|
||||
typer.secho(
|
||||
f"Warning: Failed to get embedding for {passage_text} (error: {str(e)}), skipping insert into VectorDB.",
|
||||
fg=typer.colors.YELLOW,
|
||||
)
|
||||
continue
|
||||
|
||||
passage = Passage(
|
||||
id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"),
|
||||
text=passage_text,
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from typing import List, Optional, Dict, Literal
|
||||
from pydantic import BaseModel, Field, Json, ConfigDict
|
||||
import uuid
|
||||
import base64
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from sqlmodel import Field, SQLModel
|
||||
from sqlalchemy import JSON, Column, BINARY, TypeDecorator
|
||||
|
||||
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS, MAX_EMBEDDING_DIM
|
||||
from memgpt.utils import get_human_text, get_persona_text, printd
|
||||
@@ -83,3 +86,38 @@ class PersonaModel(SQLModel, table=True):
|
||||
name: str = Field(..., description="The name of the persona.")
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the persona.", primary_key=True)
|
||||
user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the persona.")
|
||||
|
||||
|
||||
class SourceModel(SQLModel, table=True):
|
||||
name: str = Field(..., description="The name of the source.")
|
||||
description: str = Field(None, description="The description of the source.")
|
||||
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the source.")
|
||||
created_at: datetime = Field(default_factory=datetime.now, description="The unix timestamp of when the source was created.")
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the source.", primary_key=True)
|
||||
# embedding info
|
||||
# embedding_config: EmbeddingConfigModel = Field(..., description="The embedding configuration used by the source.")
|
||||
embedding_config: Optional[EmbeddingConfigModel] = Field(
|
||||
None, sa_column=Column(JSON), description="The embedding configuration used by the passage."
|
||||
)
|
||||
|
||||
|
||||
class PassageModel(BaseModel):
|
||||
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user associated with the passage.")
|
||||
agent_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the agent associated with the passage.")
|
||||
text: str = Field(..., description="The text of the passage.")
|
||||
embedding: Optional[List[float]] = Field(None, description="The embedding of the passage.")
|
||||
embedding_config: Optional[EmbeddingConfigModel] = Field(
|
||||
None, sa_column=Column(JSON), description="The embedding configuration used by the passage."
|
||||
)
|
||||
data_source: Optional[str] = Field(None, description="The data source of the passage.")
|
||||
doc_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the document associated with the passage.")
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the passage.", primary_key=True)
|
||||
metadata: Optional[Dict] = Field({}, description="The metadata of the passage.")
|
||||
|
||||
|
||||
class DocumentModel(BaseModel):
|
||||
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the document.")
|
||||
text: str = Field(..., description="The text of the document.")
|
||||
data_source: str = Field(..., description="The data source of the document.")
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the document.", primary_key=True)
|
||||
metadata: Optional[Dict] = Field({}, description="The metadata of the document.")
|
||||
|
||||
@@ -100,10 +100,8 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
return CreateUserResponse(user_id=str(new_user_ret.id), api_key=token.token)
|
||||
|
||||
@router.delete("/users", tags=["admin"], response_model=DeleteUserResponse)
|
||||
def delete_user(
|
||||
user_id: str = Query(..., description="The ID of the user to be deleted."),
|
||||
):
|
||||
@router.delete("/users/{user_id}", tags=["admin"], response_model=DeleteUserResponse)
|
||||
def delete_user(user_id):
|
||||
# TODO make a soft deletion, instead of a hard deletion
|
||||
try:
|
||||
user_id_uuid = uuid.UUID(user_id)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import re
|
||||
import uuid
|
||||
from functools import partial
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Body, Depends, Query, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
from memgpt.models.pydantic_models import AgentStateModel, LLMConfigModel, EmbeddingConfigModel
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
@@ -20,6 +20,7 @@ class GetAgentRequest(BaseModel):
|
||||
|
||||
|
||||
class AgentRenameRequest(BaseModel):
|
||||
agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.")
|
||||
agent_name: str = Field(..., description="New name for the agent.")
|
||||
|
||||
|
||||
@@ -50,9 +51,9 @@ def validate_agent_name(name: str) -> str:
|
||||
def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.get("/agents/{agent_id}", tags=["agents"], response_model=GetAgentResponse)
|
||||
@router.get("/agents", tags=["agents"], response_model=GetAgentResponse)
|
||||
def get_agent_config(
|
||||
agent_id: uuid.UUID,
|
||||
agent_id: str = Query(..., description="Unique identifier of the agent whose config is requested."),
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
@@ -68,8 +69,8 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
|
||||
interface.clear()
|
||||
agent_state = server.get_agent_config(user_id=user_id, agent_id=agent_id)
|
||||
# return GetAgentResponse(agent_state=agent_state)
|
||||
LLMConfigModel(**vars(agent_state.llm_config))
|
||||
EmbeddingConfigModel(**vars(agent_state.embedding_config))
|
||||
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
|
||||
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))
|
||||
|
||||
return GetAgentResponse(
|
||||
agent_state=AgentStateModel(
|
||||
@@ -89,9 +90,8 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
|
||||
sources=attached_sources,
|
||||
)
|
||||
|
||||
@router.patch("/agents/{agent_id}/rename", tags=["agents"], response_model=GetAgentResponse)
|
||||
@router.patch("/agents/rename", tags=["agents"], response_model=GetAgentResponse)
|
||||
def update_agent_name(
|
||||
agent_id: uuid.UUID,
|
||||
request: AgentRenameRequest = Body(...),
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
@@ -100,6 +100,8 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
|
||||
|
||||
This changes the name of the agent in the database but does NOT edit the agent's persona.
|
||||
"""
|
||||
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
|
||||
|
||||
valid_name = validate_agent_name(request.agent_name)
|
||||
|
||||
interface.clear()
|
||||
@@ -113,15 +115,13 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
|
||||
|
||||
@router.delete("/agents/{agent_id}", tags=["agents"])
|
||||
def delete_agent(
|
||||
agent_id: uuid.UUID,
|
||||
agent_id,
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Delete an agent.
|
||||
"""
|
||||
request = GetAgentRequest(agent_id=agent_id)
|
||||
|
||||
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
|
||||
agent_id = uuid.UUID(agent_id)
|
||||
|
||||
interface.clear()
|
||||
try:
|
||||
|
||||
@@ -21,6 +21,7 @@ from memgpt.server.rest_api.openai_assistants.assistants import setup_openai_ass
|
||||
from memgpt.server.rest_api.personas.index import setup_personas_index_router
|
||||
from memgpt.server.rest_api.static_files import mount_static_files
|
||||
from memgpt.server.rest_api.tools.index import setup_tools_index_router
|
||||
from memgpt.server.rest_api.sources.index import setup_sources_index_router
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
"""
|
||||
@@ -94,6 +95,7 @@ app.include_router(setup_humans_index_router(server, interface, password), prefi
|
||||
app.include_router(setup_personas_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_tools_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
|
||||
# /api/config endpoints
|
||||
app.include_router(setup_config_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
|
||||
0
memgpt/server/rest_api/sources/__init__.py
Normal file
0
memgpt/server/rest_api/sources/__init__.py
Normal file
165
memgpt/server/rest_api/sources/index.py
Normal file
165
memgpt/server/rest_api/sources/index.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import uuid
|
||||
from functools import partial
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Query, HTTPException, status, UploadFile
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from memgpt.models.pydantic_models import SourceModel, PassageModel, DocumentModel
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.data_types import Source
|
||||
from memgpt.data_sources.connectors import DirectoryConnector
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
"""
|
||||
Implement the following functions:
|
||||
* List all available sources
|
||||
* Create a new source
|
||||
* Delete a source
|
||||
* Upload a file to a server that is loaded into a specific source
|
||||
* Paginated get all passages from a source
|
||||
* Paginated get all documents from a source
|
||||
* Attach a source to an agent
|
||||
"""
|
||||
|
||||
|
||||
class ListSourcesResponse(BaseModel):
|
||||
sources: List[SourceModel] = Field(..., description="List of available sources")
|
||||
|
||||
|
||||
class CreateSourceRequest(BaseModel):
|
||||
name: str = Field(..., description="The name of the source.")
|
||||
description: Optional[str] = Field(None, description="The description of the source.")
|
||||
|
||||
|
||||
class CreateSourceResponse(BaseModel):
|
||||
source: SourceModel = Field(..., description="The newly created source.")
|
||||
|
||||
|
||||
class UploadFileToSourceRequest(BaseModel):
|
||||
file: UploadFile = Field(..., description="The file to upload.")
|
||||
|
||||
|
||||
class UploadFileToSourceResponse(BaseModel):
|
||||
source: SourceModel = Field(..., description="The source the file was uploaded to.")
|
||||
added_passages: int = Field(..., description="The number of passages added to the source.")
|
||||
added_documents: int = Field(..., description="The number of documents added to the source.")
|
||||
|
||||
|
||||
class GetSourcePassagesResponse(BaseModel):
|
||||
passages: List[PassageModel] = Field(..., description="List of passages from the source.")
|
||||
|
||||
|
||||
class GetSourceDocumentsResponse(BaseModel):
|
||||
documents: List[DocumentModel] = Field(..., description="List of documents from the source.")
|
||||
|
||||
|
||||
def setup_sources_index_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.get("/sources", tags=["sources"], response_model=ListSourcesResponse)
|
||||
async def list_source(
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
# Clear the interface
|
||||
interface.clear()
|
||||
|
||||
sources = server.ms.list_sources(user_id=user_id)
|
||||
return ListSourcesResponse(sources=sources)
|
||||
|
||||
@router.post("/sources", tags=["sources"], response_model=SourceModel)
|
||||
async def create_source(
|
||||
request: CreateSourceRequest = Body(...),
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
interface.clear()
|
||||
# TODO: don't use Source and just use SourceModel once pydantic migration is complete
|
||||
source = server.create_source(name=request.name, user_id=user_id)
|
||||
return SourceModel(
|
||||
name=source.name,
|
||||
description=None, # TODO: actually store descriptions
|
||||
user_id=source.user_id,
|
||||
id=source.id,
|
||||
embedding_config=server.server_embedding_config,
|
||||
created_at=source.created_at.timestamp(),
|
||||
)
|
||||
|
||||
@router.delete("/sources/{source_id}", tags=["sources"])
|
||||
async def delete_source(
|
||||
source_id,
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
interface.clear()
|
||||
try:
|
||||
server.delete_source(source_id=uuid.UUID(source_id), user_id=user_id)
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Source source_id={source_id} successfully deleted"})
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
|
||||
@router.post("/sources/attach", tags=["sources"], response_model=SourceModel)
|
||||
async def attach_source_to_agent(
|
||||
agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to attach the source to."),
|
||||
source_name: str = Query(..., description="The name of the source to attach."),
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
interface.clear()
|
||||
assert isinstance(agent_id, uuid.UUID), f"Expected agent_id to be a UUID, got {agent_id}"
|
||||
assert isinstance(user_id, uuid.UUID), f"Expected user_id to be a UUID, got {user_id}"
|
||||
source = server.attach_source_to_agent(source_name=source_name, agent_id=agent_id, user_id=user_id)
|
||||
return SourceModel(
|
||||
name=source.name,
|
||||
description=None, # TODO: actually store descriptions
|
||||
user_id=source.user_id,
|
||||
id=source.id,
|
||||
embedding_config=server.server_embedding_config,
|
||||
created_at=source.created_at,
|
||||
)
|
||||
|
||||
@router.post("/sources/detach", tags=["sources"], response_model=SourceModel)
|
||||
async def detach_source_from_agent(
|
||||
agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to detach the source from."),
|
||||
source_name: str = Query(..., description="The name of the source to detach."),
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
server.detach_source_from_agent(source_name=source_name, agent_id=agent_id, user_id=user_id)
|
||||
|
||||
@router.post("/sources/upload", tags=["sources"], response_model=UploadFileToSourceResponse)
|
||||
async def upload_file_to_source(
|
||||
# file: UploadFile = UploadFile(..., description="The file to upload."),
|
||||
file: UploadFile,
|
||||
source_id: uuid.UUID = Query(..., description="The unique identifier of the source to attach."),
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
interface.clear()
|
||||
source = server.ms.get_source(source_id=source_id, user_id=user_id)
|
||||
|
||||
# create a directory connector that reads the in-memory file
|
||||
connector = DirectoryConnector(input_files=[file.filename])
|
||||
|
||||
# load the data into the source via the connector
|
||||
server.load_data(user_id=user_id, source_name=source.name, connector=connector)
|
||||
|
||||
# TODO: actually return added passages/documents
|
||||
return UploadFileToSourceResponse(source=source, added_passages=0, added_documents=0)
|
||||
|
||||
@router.get("/sources/passages ", tags=["sources"], response_model=GetSourcePassagesResponse)
|
||||
async def list_passages(
|
||||
source_id: uuid.UUID = Body(...),
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@router.get("/sources/documents", tags=["sources"], response_model=GetSourceDocumentsResponse)
|
||||
async def list_documents(
|
||||
source_id: uuid.UUID = Body(...),
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
return router
|
||||
@@ -1075,13 +1075,25 @@ class SyncServer(LockingServer):
|
||||
embedding_dim=self.config.default_embedding_config.embedding_dim,
|
||||
)
|
||||
self.ms.create_source(source)
|
||||
assert self.ms.get_source(source_name=name, user_id=user_id) is not None, f"Failed to create source {name}"
|
||||
return source
|
||||
|
||||
def delete_source(self, source_id: uuid.UUID, user_id: uuid.UUID):
|
||||
"""Delete a data source"""
|
||||
source = self.ms.get_source(source_id=source_id, user_id=user_id)
|
||||
self.ms.delete_source(source_id)
|
||||
|
||||
# delete data from passage store
|
||||
passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
|
||||
passage_store.delete({"data_source": source.name})
|
||||
|
||||
# TODO: delete data from agent passage stores (?)
|
||||
|
||||
def load_data(
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
connector: DataConnector,
|
||||
source_name: Source,
|
||||
source_name: str,
|
||||
):
|
||||
"""Load data from a DataConnector into a source for a specified user_id"""
|
||||
# TODO: this should be implemented as a batch job or at least async, since it may take a long time
|
||||
@@ -1103,7 +1115,7 @@ class SyncServer(LockingServer):
|
||||
# attach a data source to an agent
|
||||
data_source = self.ms.get_source(source_name=source_name, user_id=user_id)
|
||||
if data_source is None:
|
||||
raise ValueError(f"Data source {source_name} does not exist")
|
||||
raise ValueError(f"Data source {source_name} does not exist for user_id {user_id}")
|
||||
|
||||
# get connection to data source storage
|
||||
source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
|
||||
@@ -1114,6 +1126,12 @@ class SyncServer(LockingServer):
|
||||
# attach source to agent
|
||||
agent.attach_source(data_source.name, source_connector, self.ms)
|
||||
|
||||
return data_source
|
||||
|
||||
def detach_source_from_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str):
|
||||
# TODO: remove all passages coresponding to source from agent's archival memory
|
||||
raise NotImplementedError
|
||||
|
||||
def list_attached_sources(self, agent_id: uuid.UUID):
|
||||
# list all attached sources to an agent
|
||||
return self.ms.list_attached_sources(agent_id)
|
||||
|
||||
16
poetry.lock
generated
16
poetry.lock
generated
@@ -4103,6 +4103,20 @@ files = [
|
||||
[package.extras]
|
||||
cli = ["click (>=5.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "python-multipart"
|
||||
version = "0.0.9"
|
||||
description = "A streaming multipart parser for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "python_multipart-0.0.9-py3-none-any.whl", hash = "sha256:97ca7b8ea7b05f977dc3849c3ba99d51689822fab725c3703af7c866a0c2b215"},
|
||||
{file = "python_multipart-0.0.9.tar.gz", hash = "sha256:03f54688c663f1b7977105f021043b0793151e4cb1c1a9d4a11fc13d622c4026"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
dev = ["atomicwrites (==1.4.1)", "attrs (==23.2.0)", "coverage (==7.4.1)", "hatch", "invoke (==2.2.0)", "more-itertools (==10.2.0)", "pbr (==6.0.0)", "pluggy (==1.4.0)", "py (==1.11.0)", "pytest (==8.0.0)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.2.0)", "pyyaml (==6.0.1)", "ruff (==0.2.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytz"
|
||||
version = "2023.4"
|
||||
@@ -5898,4 +5912,4 @@ server = ["fastapi", "uvicorn", "websockets"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "<3.12,>=3.10"
|
||||
content-hash = "1b35809af89064c19823842ed40f5c8b9ceb8e68315fb482e2ae3b9f8cac0fad"
|
||||
content-hash = "509bcb6fde67eb0c2d0a3997d6401e752d4c0cf7397a363f35167e6dc714ab0b"
|
||||
|
||||
@@ -57,6 +57,7 @@ llama-index = "^0.10.6"
|
||||
llama-index-embeddings-openai = "^0.1.1"
|
||||
llama-index-embeddings-huggingface = {version = "^0.1.4", optional = true}
|
||||
llama-index-embeddings-azure-openai = "^0.1.6"
|
||||
python-multipart = "^0.0.9"
|
||||
|
||||
[tool.poetry.extras]
|
||||
local = ["llama-index-embeddings-huggingface"]
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
import uuid
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from memgpt import Admin, create_client
|
||||
from memgpt.constants import DEFAULT_PRESET
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
import uuid
|
||||
|
||||
|
||||
test_agent_name = f"test_client_{str(uuid.uuid4())}"
|
||||
# test_preset_name = "test_preset"
|
||||
test_preset_name = DEFAULT_PRESET
|
||||
@@ -60,6 +59,7 @@ def user_token():
|
||||
|
||||
# Fixture to create clients with different configurations
|
||||
@pytest.fixture(params=[{"base_url": test_base_url}, {"base_url": None}], scope="module")
|
||||
# @pytest.fixture(params=[{"base_url": test_base_url}], scope="module")
|
||||
def client(request, user_token):
|
||||
# use token or not
|
||||
if request.param["base_url"]:
|
||||
@@ -71,6 +71,17 @@ def client(request, user_token):
|
||||
yield client
|
||||
|
||||
|
||||
# Fixture for test agent
|
||||
@pytest.fixture(scope="module")
|
||||
def agent(client):
|
||||
agent_state = client.create_agent(name=test_agent_name, preset=test_preset_name)
|
||||
print("AGENT ID", agent_state.id)
|
||||
yield agent_state
|
||||
|
||||
# delete agent
|
||||
client.delete_agent(agent_state.id)
|
||||
|
||||
|
||||
# TODO: add back once REST API supports
|
||||
# def test_create_preset(client):
|
||||
#
|
||||
@@ -86,26 +97,55 @@ def client(request, user_token):
|
||||
# client.create_preset(preset)
|
||||
|
||||
|
||||
def test_create_agent(client):
|
||||
global test_agent_state
|
||||
test_agent_state = client.create_agent(
|
||||
name=test_agent_name,
|
||||
preset=test_preset_name,
|
||||
)
|
||||
print(f"\n\n[1] CREATED AGENT {test_agent_state.id}!!!\n\tmessages={test_agent_state.state['messages']}")
|
||||
assert test_agent_state is not None
|
||||
# def test_create_agent(client):
|
||||
# global test_agent_state
|
||||
# test_agent_state = client.create_agent(
|
||||
# name=test_agent_name,
|
||||
# preset=test_preset_name,
|
||||
# )
|
||||
# print(f"\n\n[1] CREATED AGENT {test_agent_state.id}!!!\n\tmessages={test_agent_state.state['messages']}")
|
||||
# assert test_agent_state is not None
|
||||
|
||||
|
||||
def test_user_message(client):
|
||||
"""Test that we can send a message through the client"""
|
||||
assert client is not None, "Run create_agent test first"
|
||||
print(f"\n\n[2] SENDING MESSAGE TO AGENT {test_agent_state.id}!!!\n\tmessages={test_agent_state.state['messages']}")
|
||||
response = client.user_message(agent_id=test_agent_state.id, message="Hello my name is Test, Client Test")
|
||||
assert response is not None and len(response) > 0
|
||||
def test_sources(client, agent):
|
||||
|
||||
# global test_agent_state_post_message
|
||||
# client.server.active_agents[0]["agent"].update_state()
|
||||
# test_agent_state_post_message = client.server.active_agents[0]["agent"].agent_state
|
||||
# print(
|
||||
# f"[2] MESSAGE SEND SUCCESS!!! AGENT {test_agent_state_post_message.id}\n\tmessages={test_agent_state_post_message.state['messages']}"
|
||||
# )
|
||||
if not hasattr(client, "base_url"):
|
||||
pytest.skip("Skipping test_sources because base_url is None")
|
||||
|
||||
# list sources
|
||||
sources = client.list_sources()
|
||||
print("listed sources", sources)
|
||||
|
||||
# create a source
|
||||
source = client.create_source(name="test_source")
|
||||
|
||||
# list sources
|
||||
sources = client.list_sources()
|
||||
print("listed sources", sources)
|
||||
assert len(sources) == 1
|
||||
|
||||
# load a file into a source
|
||||
filename = "CONTRIBUTING.md"
|
||||
response = client.load_file_into_source(filename, source.id)
|
||||
print(response)
|
||||
|
||||
# attach a source
|
||||
# TODO: make sure things run in the right order
|
||||
client.attach_source_to_agent(source_name="test_source", agent_id=agent.id)
|
||||
|
||||
# TODO: list archival memory
|
||||
|
||||
# detach the source
|
||||
# TODO: add when implemented
|
||||
# client.detach_source(source.name, agent.id)
|
||||
|
||||
# delete the source
|
||||
client.delete_source(source.id)
|
||||
|
||||
|
||||
# def test_user_message(client, agent):
|
||||
# """Test that we can send a message through the client"""
|
||||
# assert client is not None, "Run create_agent test first"
|
||||
# print(f"\n\n[2] SENDING MESSAGE TO AGENT {agent.id}!!!\n\tmessages={agent.state['messages']}")
|
||||
# response = client.user_message(agent_id=agent.id, message="Hello my name is Test, Client Test")
|
||||
# assert response is not None and len(response) > 0
|
||||
|
||||
Reference in New Issue
Block a user