From 4035a211fb47b4a2ef914f07d6640e6f508df005 Mon Sep 17 00:00:00 2001
From: Matt Zhou
Date: Tue, 1 Oct 2024 15:42:59 -0700
Subject: [PATCH] wip

---
 letta/llm_api/llm_api_tools.py | 6 +++++-
 letta/schemas/llm_config.py    | 4 ++--
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py
index 93753a55..35f7c8e2 100644
--- a/letta/llm_api/llm_api_tools.py
+++ b/letta/llm_api/llm_api_tools.py
@@ -455,7 +455,7 @@ def create(
             chat_completion_request=ChatCompletionRequest(
                 model="command-r-plus",  # TODO
                 messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
-                tools=[{"type": "function", "function": f} for f in functions] if functions else None,
+                tools=tools,
                 tool_choice=function_call,
                 # user=str(user_id),
                 # NOTE: max_tokens is required for Anthropic API
@@ -463,6 +463,10 @@ def create(
             ),
         )
 
+    elif llm_config.model_endpoint_type == "groq":
+        if stream:
+            raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}")
+
     # local model
     else:
         if stream:
diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py
index 134dff02..dfc68882 100644
--- a/letta/schemas/llm_config.py
+++ b/letta/schemas/llm_config.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Literal, Optional
 
 from pydantic import BaseModel, ConfigDict, Field
 
@@ -16,7 +16,7 @@ class LLMConfig(BaseModel):
     """
 
     # TODO: 🤮 don't default to a vendor! bug city!
     model: str = Field(..., description="LLM model name. ")
-    model_endpoint_type: str = Field(..., description="The endpoint type for the model.")
+    model_endpoint_type: Literal["openai", "anthropic", "cohere", "google_ai", "azure", "groq"] = Field(..., description="The endpoint type for the model.")
     model_endpoint: str = Field(..., description="The endpoint for the model.")
     model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")