feat: return server.server_llm_config information for REST endpoint (#1083)

This commit is contained in:
Sarah Wooders
2024-03-02 13:42:37 -08:00
committed by GitHub
parent 40b23987ed
commit bfea6ed7f4

View File

@@ -9,43 +9,32 @@ from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
from memgpt.models.pydantic_models import LLMConfigModel, EmbeddingConfigModel
router = APIRouter()
class Model(BaseModel):
    """Hardcoded response schema for one model entry.

    NOTE(review): the endpoint below now builds LLMConfigModel entries from
    server.server_llm_config, so this class appears to be superseded — confirm
    it has no remaining callers before relying on it.
    """

    name: str = Field(..., description="The name of the model.")
    endpoint: str = Field(..., description="Endpoint URL for the model.")
    endpoint_type: str = Field(..., description="Type of the model endpoint.")
    # NOTE(review): annotated `str` but defaulted to None; pydantic v1 treats
    # this as implicitly optional — confirm Optional[str] is the intent.
    wrapper: str = Field(None, description="Wrapper used for the model.")
    context_window: int = Field(..., description="Context window size for the model.")
class ListModelsResponse(BaseModel):
    """Response payload for GET /models: the LLM configurations the server exposes."""

    # The earlier duplicate declaration of this field (typed List[Model]) was
    # dead code — when a class body assigns the same name twice, only the last
    # binding takes effect — so it is removed; List[LLMConfigModel] is kept.
    models: List[LLMConfigModel] = Field(..., description="List of model configurations.")
def setup_models_index_router(server: SyncServer, interface: QueuingInterface, password: str):
    """Register the GET /models endpoint on the module-level router and return it.

    The original span interleaved the pre- and post-change code from a diff
    (two consecutive `async def list_models` headers — a syntax error — plus a
    hardcoded `models_data` list alongside the LLMConfigModel-based body).
    This is the reconstructed post-change implementation, which reports the
    server's actual LLM configuration.

    Args:
        server: SyncServer whose `server_llm_config` is exposed by the endpoint.
        interface: Per-request queuing interface; cleared before each request.
        password: Server password used to build the auth dependency.

    Returns:
        The shared APIRouter with the /models route attached.
    """
    # Auth dependency is prepared even though list_models does not currently
    # declare it — NOTE(review): confirm dropping auth on /models is intended.
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.get("/models", tags=["models"], response_model=ListModelsResponse)
    async def list_models():
        # Clear the interface
        interface.clear()

        # currently, the server only supports one model, however this may change in the future
        llm_config = LLMConfigModel(
            model=server.server_llm_config.model,
            model_endpoint=server.server_llm_config.model_endpoint,
            model_endpoint_type=server.server_llm_config.model_endpoint_type,
            model_wrapper=server.server_llm_config.model_wrapper,
            context_window=server.server_llm_config.context_window,
        )

        return ListModelsResponse(models=[llm_config])

    return router