fix: patch errors with OllamaProvider (#1875)

This commit is contained in:
Charles Packer
2024-10-12 20:19:20 -07:00
committed by GitHub
parent 9d0da9549b
commit 1099809e49
4 changed files with 65 additions and 17 deletions

View File

@@ -140,9 +140,17 @@ class AnthropicProvider(Provider):
class OllamaProvider(OpenAIProvider):
"""Ollama provider that uses the native /api/generate endpoint
See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
"""
name: str = "ollama"
base_url: str = Field(..., description="Base URL for the Ollama API.")
api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).")
default_prompt_formatter: str = Field(
..., description="Default prompt formatter (aka model wrapper) to use on a /completions style API."
)
def list_llm_models(self) -> List[LLMConfig]:
# https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
@@ -156,11 +164,15 @@ class OllamaProvider(OpenAIProvider):
configs = []
for model in response_json["models"]:
context_window = self.get_model_context_window(model["name"])
if context_window is None:
print(f"Ollama model {model['name']} has no context window")
continue
configs.append(
LLMConfig(
model=model["name"],
model_endpoint_type="ollama",
model_endpoint=self.base_url,
model_wrapper=self.default_prompt_formatter,
context_window=context_window,
)
)
@@ -192,6 +204,10 @@ class OllamaProvider(OpenAIProvider):
# ]
# max_position_embeddings
# parse model cards: nous, dolphon, llama
if "model_info" not in response_json:
if "error" in response_json:
print(f"Ollama fetch model info error for {model_name}: {response_json['error']}")
return None
for key, value in response_json["model_info"].items():
if "context_length" in key:
return value
@@ -202,6 +218,10 @@ class OllamaProvider(OpenAIProvider):
response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True})
response_json = response.json()
if "model_info" not in response_json:
if "error" in response_json:
print(f"Ollama fetch model info error for {model_name}: {response_json['error']}")
return None
for key, value in response_json["model_info"].items():
if "embedding_length" in key:
return value
@@ -220,6 +240,7 @@ class OllamaProvider(OpenAIProvider):
for model in response_json["models"]:
embedding_dim = self.get_model_embedding_dim(model["name"])
if not embedding_dim:
print(f"Ollama model {model['name']} has no embedding dimension")
continue
configs.append(
EmbeddingConfig(
@@ -420,7 +441,7 @@ class VLLMCompletionsProvider(Provider):
# NOTE: vLLM only serves one model at a time (so could configure that through env variables)
name: str = "vllm"
base_url: str = Field(..., description="Base URL for the vLLM API.")
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper)to use on vLLM /completions API.")
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
def list_llm_models(self) -> List[LLMConfig]:
# not supported with vLLM