fix: patch errors with OllamaProvider (#1875)
@@ -85,9 +85,7 @@ def get_chat_completion(
     elif wrapper is None:
         # Warn the user that we're using the fallback
         if not has_shown_warning:
-            print(
-                f"{CLI_WARNING_PREFIX}no wrapper specified for local LLM, using the default wrapper (you can remove this warning by specifying the wrapper with --model-wrapper)"
-            )
+            print(f"{CLI_WARNING_PREFIX}no prompt formatter specified for local LLM, using the default formatter")
             has_shown_warning = True
 
         llm_wrapper = DEFAULT_WRAPPER()
@@ -140,9 +140,17 @@ class AnthropicProvider(Provider):
 
 
 class OllamaProvider(OpenAIProvider):
+    """Ollama provider that uses the native /api/generate endpoint
+
+    See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
+    """
 
     name: str = "ollama"
     base_url: str = Field(..., description="Base URL for the Ollama API.")
+    api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).")
+    default_prompt_formatter: str = Field(
+        ..., description="Default prompt formatter (aka model wrapper) to use on a /completions style API."
+    )
 
     def list_llm_models(self) -> List[LLMConfig]:
         # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
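
As background for the docstring added above, a minimal sketch of the native /api/generate call it points to, assuming a local Ollama server on the default port and an already-pulled model (the model name is illustrative, not from this commit):

    import requests

    # POST to Ollama's native completion endpoint (see the docs URL above);
    # "llama3" and the port are assumptions for illustration
    response = requests.post(
        "http://localhost:11434/api/generate",
        json={"model": "llama3", "prompt": "Why is the sky blue?", "stream": False},
    )
    print(response.json()["response"])  # the generated completion text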
@@ -156,11 +164,15 @@ class OllamaProvider(OpenAIProvider):
         configs = []
         for model in response_json["models"]:
+            context_window = self.get_model_context_window(model["name"])
+            if context_window is None:
+                print(f"Ollama model {model['name']} has no context window")
+                continue
             configs.append(
                 LLMConfig(
                     model=model["name"],
                     model_endpoint_type="ollama",
                     model_endpoint=self.base_url,
+                    model_wrapper=self.default_prompt_formatter,
                     context_window=context_window,
                 )
             )
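
The loop above walks the /api/tags listing referenced in the comment; a sketch of that payload's shape under the same assumptions (default port, illustrative model name):

    import requests

    base_url = "http://localhost:11434"  # assumed default Ollama port
    response_json = requests.get(f"{base_url}/api/tags").json()
    for model in response_json["models"]:
        print(model["name"])  # e.g. "llama3:latest"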
@@ -192,6 +204,10 @@ class OllamaProvider(OpenAIProvider):
         # ]
         # max_position_embeddings
         # parse model cards: nous, dolphin, llama
+        if "model_info" not in response_json:
+            if "error" in response_json:
+                print(f"Ollama fetch model info error for {model_name}: {response_json['error']}")
+            return None
         for key, value in response_json["model_info"].items():
             if "context_length" in key:
                 return value
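
The substring match on "context_length" works because the verbose /api/show response prefixes model_info keys with the model architecture; the same reasoning applies to "embedding_length" in the next hunk. A sketch with illustrative values, not taken from this commit:

    model_info = {
        "general.architecture": "llama",
        "llama.context_length": 8192,    # matched by '"context_length" in key'
        "llama.embedding_length": 4096,  # matched by '"embedding_length" in key'
    }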
@@ -202,6 +218,10 @@ class OllamaProvider(OpenAIProvider):
 
         response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True})
         response_json = response.json()
+        if "model_info" not in response_json:
+            if "error" in response_json:
+                print(f"Ollama fetch model info error for {model_name}: {response_json['error']}")
+            return None
         for key, value in response_json["model_info"].items():
             if "embedding_length" in key:
                 return value
@@ -220,6 +240,7 @@ class OllamaProvider(OpenAIProvider):
         for model in response_json["models"]:
             embedding_dim = self.get_model_embedding_dim(model["name"])
             if not embedding_dim:
+                print(f"Ollama model {model['name']} has no embedding dimension")
                 continue
             configs.append(
                 EmbeddingConfig(
@@ -420,7 +441,7 @@ class VLLMCompletionsProvider(Provider):
     # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
     name: str = "vllm"
     base_url: str = Field(..., description="Base URL for the vLLM API.")
-    default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper)to use on vLLM /completions API.")
+    default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
 
     def list_llm_models(self) -> List[LLMConfig]:
         # not supported with vLLM
@@ -200,7 +200,7 @@ class SyncServer(Server):
     def __init__(
         self,
         chaining: bool = True,
-        max_chaining_steps: bool = None,
+        max_chaining_steps: Optional[bool] = None,
         default_interface_factory: Callable[[], AgentInterface] = lambda: CLIInterface(),
         # default_interface: AgentInterface = CLIInterface(),
         # default_persistence_manager_cls: PersistenceManager = LocalStateManager,
@@ -241,13 +241,32 @@ class SyncServer(Server):
         # collect providers (always has Letta as a default)
         self._enabled_providers: List[Provider] = [LettaProvider()]
         if model_settings.openai_api_key:
-            self._enabled_providers.append(OpenAIProvider(api_key=model_settings.openai_api_key, base_url=model_settings.openai_api_base))
+            self._enabled_providers.append(
+                OpenAIProvider(
+                    api_key=model_settings.openai_api_key,
+                    base_url=model_settings.openai_api_base,
+                )
+            )
         if model_settings.anthropic_api_key:
-            self._enabled_providers.append(AnthropicProvider(api_key=model_settings.anthropic_api_key))
+            self._enabled_providers.append(
+                AnthropicProvider(
+                    api_key=model_settings.anthropic_api_key,
+                )
+            )
         if model_settings.ollama_base_url:
-            self._enabled_providers.append(OllamaProvider(base_url=model_settings.ollama_base_url, api_key=None))
+            self._enabled_providers.append(
+                OllamaProvider(
+                    base_url=model_settings.ollama_base_url,
+                    api_key=None,
+                    default_prompt_formatter=model_settings.default_prompt_formatter,
+                )
+            )
         if model_settings.gemini_api_key:
-            self._enabled_providers.append(GoogleAIProvider(api_key=model_settings.gemini_api_key))
+            self._enabled_providers.append(
+                GoogleAIProvider(
+                    api_key=model_settings.gemini_api_key,
+                )
+            )
         if model_settings.azure_api_key and model_settings.azure_base_url:
             assert model_settings.azure_api_version, "AZURE_API_VERSION is required"
             self._enabled_providers.append(
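
For reference, constructing the provider directly mirrors this server wiring; the base URL and formatter name below are illustrative assumptions, not values taken from the commit:

    from letta.providers import OllamaProvider

    provider = OllamaProvider(
        base_url="http://localhost:11434",  # assumed local Ollama default
        api_key=None,  # Ollama takes no key; the field presumably exists for OpenAIProvider compatibility
        default_prompt_formatter="chatml",  # assumed formatter name
    )
    for config in provider.list_llm_models():
        print(config.model, config.context_window)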
@@ -268,7 +287,11 @@ class SyncServer(Server):
             # NOTE: to use the /chat/completions endpoint, you need to specify extra flags on vLLM startup
             # see: https://docs.vllm.ai/en/latest/getting_started/examples/openai_chat_completion_client_with_tools.html
             # e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
-            self._enabled_providers.append(VLLMChatCompletionsProvider(base_url=model_settings.vllm_api_base))
+            self._enabled_providers.append(
+                VLLMChatCompletionsProvider(
+                    base_url=model_settings.vllm_api_base,
+                )
+            )
 
     def save_agents(self):
         """Saves all the agents that are in the in-memory object store"""
@@ -6,19 +6,21 @@ from letta.providers import (
     OllamaProvider,
     OpenAIProvider,
 )
 from letta.settings import model_settings
 
 
 def test_openai():
-    provider = OpenAIProvider(api_key=os.getenv("OPENAI_API_KEY"))
+    api_key = os.getenv("OPENAI_API_KEY")
+    assert api_key is not None
+    provider = OpenAIProvider(api_key=api_key, base_url=model_settings.openai_api_base)
     models = provider.list_llm_models()
     print(models)
 
 
 def test_anthropic():
-    if os.getenv("ANTHROPIC_API_KEY") is None:
-        return
-    provider = AnthropicProvider(api_key=os.getenv("ANTHROPIC_API_KEY"))
+    api_key = os.getenv("ANTHROPIC_API_KEY")
+    assert api_key is not None
+    provider = AnthropicProvider(api_key=api_key)
     models = provider.list_llm_models()
     print(models)
@@ -38,7 +40,9 @@ def test_azure():
 
 
 def test_ollama():
-    provider = OllamaProvider(base_url=os.getenv("OLLAMA_BASE_URL"))
+    base_url = os.getenv("OLLAMA_BASE_URL")
+    assert base_url is not None
+    provider = OllamaProvider(base_url=base_url, default_prompt_formatter=model_settings.default_prompt_formatter, api_key=None)
     models = provider.list_llm_models()
     print(models)
@@ -47,7 +51,9 @@ def test_ollama():
 
 
 def test_googleai():
-    provider = GoogleAIProvider(api_key=os.getenv("GEMINI_API_KEY"))
+    api_key = os.getenv("GEMINI_API_KEY")
+    assert api_key is not None
+    provider = GoogleAIProvider(api_key=api_key)
     models = provider.list_llm_models()
     print(models)
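
Taken together, the test changes replace silent skips with hard asserts: each test now fails loudly when its environment variable (OPENAI_API_KEY, ANTHROPIC_API_KEY, OLLAMA_BASE_URL, GEMINI_API_KEY) is unset, and test_ollama passes the newly required default_prompt_formatter explicitly.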