This commit is contained in:
Matt Zhou
2024-10-01 15:42:59 -07:00
parent e368c1aa10
commit 4035a211fb
2 changed files with 7 additions and 3 deletions

View File

@@ -455,7 +455,7 @@ def create(
chat_completion_request=ChatCompletionRequest( chat_completion_request=ChatCompletionRequest(
model="command-r-plus", # TODO model="command-r-plus", # TODO
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages], messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
tools=[{"type": "function", "function": f} for f in functions] if functions else None, tools=tools,
tool_choice=function_call, tool_choice=function_call,
# user=str(user_id), # user=str(user_id),
# NOTE: max_tokens is required for Anthropic API # NOTE: max_tokens is required for Anthropic API
@@ -463,6 +463,10 @@ def create(
), ),
) )
elif llm_config.model_endpoint_type == "groq":
if stream:
raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}")
# local model # local model
else: else:
if stream: if stream:

View File

@@ -1,4 +1,4 @@
from typing import Optional from typing import Literal, Optional
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
@@ -16,7 +16,7 @@ class LLMConfig(BaseModel):
""" """
# TODO: 🤮 don't default to a vendor! bug city! # TODO: 🤮 don't default to a vendor! bug city!
model: str = Field(..., description="LLM model name. ") model: Literal["openai", "anthropic", "cohere", "google_ai", "azure", "groq"] = Field(..., description="LLM model name. ")
model_endpoint_type: str = Field(..., description="The endpoint type for the model.") model_endpoint_type: str = Field(..., description="The endpoint type for the model.")
model_endpoint: str = Field(..., description="The endpoint for the model.") model_endpoint: str = Field(..., description="The endpoint for the model.")
model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.") model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")