fix: patch errors with OllamaProvider (#1875)
@@ -140,9 +140,17 @@ class AnthropicProvider(Provider):
class OllamaProvider(OpenAIProvider):
    """Ollama provider that uses the native /api/generate endpoint

    See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
    """

    name: str = "ollama"
    base_url: str = Field(..., description="Base URL for the Ollama API.")
    api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).")
    default_prompt_formatter: str = Field(
        ..., description="Default prompt formatter (aka model wrapper) to use on a /completions style API."
    )

    def list_llm_models(self) -> List[LLMConfig]:
        # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
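For context, a minimal sketch of the native /api/generate call the docstring above refers to. The base URL and model name here are placeholders (a local Ollama server at the default http://localhost:11434), not part of this change:

import requests

# Non-streaming completion against Ollama's native generate endpoint
base_url = "http://localhost:11434"
payload = {"model": "llama3", "prompt": "Why is the sky blue?", "stream": False}
response = requests.post(f"{base_url}/api/generate", json=payload)
response.raise_for_status()
print(response.json()["response"])  # the generated text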
@@ -156,11 +164,15 @@ class OllamaProvider(OpenAIProvider):
        configs = []
        for model in response_json["models"]:
            context_window = self.get_model_context_window(model["name"])
            if context_window is None:
                print(f"Ollama model {model['name']} has no context window")
                continue
            configs.append(
                LLMConfig(
                    model=model["name"],
                    model_endpoint_type="ollama",
                    model_endpoint=self.base_url,
                    model_wrapper=self.default_prompt_formatter,
                    context_window=context_window,
                )
            )
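For reference, the response_json iterated here comes from the list-local-models endpoint linked above. A rough sketch of that call, with the base URL assumed to be the local default rather than the provider's configured base_url:

import requests

# List locally installed models via Ollama's /api/tags endpoint
base_url = "http://localhost:11434"  # placeholder; the provider uses self.base_url
response_json = requests.get(f"{base_url}/api/tags").json()
for model in response_json["models"]:
    print(model["name"])  # e.g. "llama3:latest"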
@@ -192,6 +204,10 @@ class OllamaProvider(OpenAIProvider):
        # ]
        # max_position_embeddings
        # parse model cards: nous, dolphin, llama
        if "model_info" not in response_json:
            if "error" in response_json:
                print(f"Ollama fetch model info error for {model_name}: {response_json['error']}")
            return None
        for key, value in response_json["model_info"].items():
            if "context_length" in key:
                return value
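The substring match on "context_length" works because model_info keys are prefixed with the model architecture (for example "llama.context_length"). A rough standalone sketch of the same lookup; the base URL and model name are placeholders:

import requests

base_url = "http://localhost:11434"  # placeholder; the provider uses self.base_url
response = requests.post(f"{base_url}/api/show", json={"name": "llama3", "verbose": True})
model_info = response.json().get("model_info", {})
# Keys vary by architecture, e.g. "llama.context_length" or "qwen2.context_length",
# hence the architecture-agnostic substring check above.
context_length = next((v for k, v in model_info.items() if "context_length" in k), None)
print(context_length)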
@@ -202,6 +218,10 @@ class OllamaProvider(OpenAIProvider):
        response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True})
        response_json = response.json()
        if "model_info" not in response_json:
            if "error" in response_json:
                print(f"Ollama fetch model info error for {model_name}: {response_json['error']}")
            return None
        for key, value in response_json["model_info"].items():
            if "embedding_length" in key:
                return value
@@ -220,6 +240,7 @@ class OllamaProvider(OpenAIProvider):
        for model in response_json["models"]:
            embedding_dim = self.get_model_embedding_dim(model["name"])
            if not embedding_dim:
                print(f"Ollama model {model['name']} has no embedding dimension")
                continue
            configs.append(
                EmbeddingConfig(
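As a cross-check on the embedding_length read from model_info, the dimension can also be observed directly from Ollama's embeddings endpoint. A rough sketch; the model name and prompt are assumptions, not part of this change:

import requests

base_url = "http://localhost:11434"  # placeholder; the provider uses self.base_url
resp = requests.post(f"{base_url}/api/embeddings", json={"model": "mxbai-embed-large", "prompt": "hello"})
embedding = resp.json()["embedding"]
print(len(embedding))  # should line up with the embedding_length reported by /api/show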
@@ -420,7 +441,7 @@ class VLLMCompletionsProvider(Provider):
    # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
    name: str = "vllm"
    base_url: str = Field(..., description="Base URL for the vLLM API.")
-   default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper)to use on vLLM /completions API.")
+   default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")

    def list_llm_models(self) -> List[LLMConfig]:
        # not supported with vLLM