chore: merge changes from oss (#964)
This commit is contained in:
86
.github/workflows/tests.yml
vendored
86
.github/workflows/tests.yml
vendored
@@ -1,86 +0,0 @@
|
||||
name: Unit Tests
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
COMPOSIO_API_KEY: ${{ secrets.COMPOSIO_API_KEY }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
|
||||
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
|
||||
E2B_API_KEY: ${{ secrets.E2B_API_KEY }}
|
||||
E2B_SANDBOX_TEMPLATE_ID: ${{ secrets.E2B_SANDBOX_TEMPLATE_ID }}
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
unit-run:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
test_suite:
|
||||
- "test_vector_embeddings.py"
|
||||
- "test_client.py"
|
||||
- "test_client_legacy.py"
|
||||
- "test_server.py"
|
||||
- "test_v1_routes.py"
|
||||
- "test_local_client.py"
|
||||
- "test_managers.py"
|
||||
- "test_base_functions.py"
|
||||
- "test_tool_schema_parsing.py"
|
||||
- "test_tool_rule_solver.py"
|
||||
- "test_memory.py"
|
||||
- "test_utils.py"
|
||||
- "test_stream_buffer_readers.py"
|
||||
services:
|
||||
qdrant:
|
||||
image: qdrant/qdrant
|
||||
ports:
|
||||
- 6333:6333
|
||||
postgres:
|
||||
image: pgvector/pgvector:pg17
|
||||
ports:
|
||||
- 5432:5432
|
||||
env:
|
||||
POSTGRES_HOST_AUTH_METHOD: trust
|
||||
POSTGRES_DB: postgres
|
||||
POSTGRES_USER: postgres
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python, Poetry, and Dependencies
|
||||
uses: packetcoders/action-setup-cache-python-poetry@main
|
||||
with:
|
||||
python-version: "3.12"
|
||||
poetry-version: "1.8.2"
|
||||
install-args: "-E dev -E postgres -E external-tools -E tests -E cloud-tool-sandbox"
|
||||
- name: Migrate database
|
||||
env:
|
||||
LETTA_PG_PORT: 5432
|
||||
LETTA_PG_USER: postgres
|
||||
LETTA_PG_PASSWORD: postgres
|
||||
LETTA_PG_DB: postgres
|
||||
LETTA_PG_HOST: localhost
|
||||
run: |
|
||||
psql -h localhost -U postgres -d postgres -c 'CREATE EXTENSION vector'
|
||||
poetry run alembic upgrade head
|
||||
- name: Run core unit tests
|
||||
env:
|
||||
LETTA_PG_PORT: 5432
|
||||
LETTA_PG_USER: postgres
|
||||
LETTA_PG_PASSWORD: postgres
|
||||
LETTA_PG_DB: postgres
|
||||
LETTA_PG_HOST: localhost
|
||||
LETTA_SERVER_PASS: test_server_token
|
||||
run: |
|
||||
poetry run pytest -s -vv tests/${{ matrix.test_suite }}
|
||||
@@ -1,5 +1,4 @@
|
||||
__version__ = "0.6.13"
|
||||
|
||||
__version__ = "0.6.23"
|
||||
|
||||
# import clients
|
||||
from letta.client.client import LocalClient, RESTClient, create_client
|
||||
|
||||
@@ -260,6 +260,7 @@ class Agent(BaseAgent):
|
||||
error_msg: str,
|
||||
tool_call_id: str,
|
||||
function_name: str,
|
||||
function_args: dict,
|
||||
function_response: str,
|
||||
messages: List[Message],
|
||||
include_function_failed_message: bool = False,
|
||||
@@ -394,6 +395,7 @@ class Agent(BaseAgent):
|
||||
|
||||
messages = [] # append these to the history when done
|
||||
function_name = None
|
||||
function_args = {}
|
||||
|
||||
# Step 2: check if LLM wanted to call a function
|
||||
if response_message.function_call or (response_message.tool_calls is not None and len(response_message.tool_calls) > 0):
|
||||
@@ -431,6 +433,7 @@ class Agent(BaseAgent):
|
||||
openai_message_dict=response_message.model_dump(),
|
||||
)
|
||||
) # extend conversation with assistant's reply
|
||||
self.logger.info(f"Function call message: {messages[-1]}")
|
||||
|
||||
nonnull_content = False
|
||||
if response_message.content:
|
||||
@@ -445,6 +448,7 @@ class Agent(BaseAgent):
|
||||
response_message.function_call if response_message.function_call is not None else response_message.tool_calls[0].function
|
||||
)
|
||||
function_name = function_call.name
|
||||
self.logger.info(f"Request to call function {function_name} with tool_call_id: {tool_call_id}")
|
||||
|
||||
# Failure case 1: function name is wrong (not in agent_state.tools)
|
||||
target_letta_tool = None
|
||||
@@ -455,7 +459,9 @@ class Agent(BaseAgent):
|
||||
if not target_letta_tool:
|
||||
error_msg = f"No function named {function_name}"
|
||||
function_response = "None" # more like "never ran?"
|
||||
messages = self._handle_function_error_response(error_msg, tool_call_id, function_name, function_response, messages)
|
||||
messages = self._handle_function_error_response(
|
||||
error_msg, tool_call_id, function_name, function_args, function_response, messages
|
||||
)
|
||||
return messages, False, True # force a heartbeat to allow agent to handle error
|
||||
|
||||
# Failure case 2: function name is OK, but function args are bad JSON
|
||||
@@ -465,7 +471,9 @@ class Agent(BaseAgent):
|
||||
except Exception:
|
||||
error_msg = f"Error parsing JSON for function '{function_name}' arguments: {function_call.arguments}"
|
||||
function_response = "None" # more like "never ran?"
|
||||
messages = self._handle_function_error_response(error_msg, tool_call_id, function_name, function_response, messages)
|
||||
messages = self._handle_function_error_response(
|
||||
error_msg, tool_call_id, function_name, function_args, function_response, messages
|
||||
)
|
||||
return messages, False, True # force a heartbeat to allow agent to handle error
|
||||
|
||||
# Check if inner thoughts is in the function call arguments (possible apparently if you are using Azure)
|
||||
@@ -502,7 +510,7 @@ class Agent(BaseAgent):
|
||||
|
||||
if sandbox_run_result and sandbox_run_result.status == "error":
|
||||
messages = self._handle_function_error_response(
|
||||
function_response, tool_call_id, function_name, function_response, messages
|
||||
function_response, tool_call_id, function_name, function_args, function_response, messages
|
||||
)
|
||||
return messages, False, True # force a heartbeat to allow agent to handle error
|
||||
|
||||
@@ -531,7 +539,7 @@ class Agent(BaseAgent):
|
||||
error_msg_user = f"{error_msg}\n{traceback.format_exc()}"
|
||||
self.logger.error(error_msg_user)
|
||||
messages = self._handle_function_error_response(
|
||||
error_msg, tool_call_id, function_name, function_response, messages, include_function_failed_message=True
|
||||
error_msg, tool_call_id, function_name, function_args, function_response, messages, include_function_failed_message=True
|
||||
)
|
||||
return messages, False, True # force a heartbeat to allow agent to handle error
|
||||
|
||||
@@ -539,7 +547,7 @@ class Agent(BaseAgent):
|
||||
if function_response_string.startswith(ERROR_MESSAGE_PREFIX):
|
||||
error_msg = function_response_string
|
||||
messages = self._handle_function_error_response(
|
||||
error_msg, tool_call_id, function_name, function_response, messages, include_function_failed_message=True
|
||||
error_msg, tool_call_id, function_name, function_args, function_response, messages, include_function_failed_message=True
|
||||
)
|
||||
return messages, False, True # force a heartbeat to allow agent to handle error
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import OptionState
|
||||
from letta.schemas.memory import ChatMemory, Memory
|
||||
from letta.server.server import logger as server_logger
|
||||
|
||||
# from letta.interface import CLIInterface as interface # for printing to terminal
|
||||
from letta.streaming_interface import StreamingRefreshCLIInterface as interface # for printing to terminal
|
||||
@@ -119,6 +118,8 @@ def run(
|
||||
utils.DEBUG = debug
|
||||
# TODO: add logging command line options for runtime log level
|
||||
|
||||
from letta.server.server import logger as server_logger
|
||||
|
||||
if debug:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
server_logger.setLevel(logging.DEBUG)
|
||||
@@ -360,4 +361,4 @@ def delete_agent(
|
||||
def version() -> str:
|
||||
import letta
|
||||
|
||||
return letta.__version__
|
||||
print(letta.__version__)
|
||||
|
||||
@@ -167,6 +167,27 @@ class OllamaEmbeddings:
|
||||
return response_json["embedding"]
|
||||
|
||||
|
||||
class GoogleEmbeddings:
|
||||
def __init__(self, api_key: str, model: str, base_url: str):
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.base_url = base_url # Expected to be "https://generativelanguage.googleapis.com"
|
||||
|
||||
def get_text_embedding(self, text: str):
|
||||
import httpx
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
# Build the URL based on the provided base_url, model, and API key.
|
||||
url = f"{self.base_url}/v1beta/models/{self.model}:embedContent?key={self.api_key}"
|
||||
payload = {"model": self.model, "content": {"parts": [{"text": text}]}}
|
||||
with httpx.Client() as client:
|
||||
response = client.post(url, headers=headers, json=payload)
|
||||
# Raise an error for non-success HTTP status codes.
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
return response_json["embedding"]["values"]
|
||||
|
||||
|
||||
def query_embedding(embedding_model, query_text: str):
|
||||
"""Generate padded embedding for querying database"""
|
||||
query_vec = embedding_model.get_text_embedding(query_text)
|
||||
@@ -237,5 +258,14 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
|
||||
)
|
||||
return model
|
||||
|
||||
elif endpoint_type == "google_ai":
|
||||
assert all([model_settings.gemini_api_key is not None, model_settings.gemini_base_url is not None])
|
||||
model = GoogleEmbeddings(
|
||||
model=config.embedding_model,
|
||||
api_key=model_settings.gemini_api_key,
|
||||
base_url=model_settings.gemini_base_url,
|
||||
)
|
||||
return model
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown endpoint type {endpoint_type}")
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from letta.cli.cli import version
|
||||
from letta import __version__
|
||||
from letta.schemas.health import Health
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -15,6 +15,6 @@ router = APIRouter(prefix="/health", tags=["health"])
|
||||
@router.get("/", response_model=Health, operation_id="health_check")
|
||||
def health_check():
|
||||
return Health(
|
||||
version=version(),
|
||||
version=__version__,
|
||||
status="ok",
|
||||
)
|
||||
|
||||
@@ -85,7 +85,7 @@ class ModelSettings(BaseSettings):
|
||||
|
||||
# google ai
|
||||
gemini_api_key: Optional[str] = None
|
||||
|
||||
gemini_base_url: str = "https://generativelanguage.googleapis.com/"
|
||||
# together
|
||||
together_api_key: Optional[str] = None
|
||||
|
||||
|
||||
885
poetry.lock
generated
885
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "letta"
|
||||
|
||||
|
||||
version = "0.6.13"
|
||||
version = "0.6.23"
|
||||
packages = [
|
||||
{include = "letta"},
|
||||
]
|
||||
@@ -82,6 +80,7 @@ anthropic = "^0.43.0"
|
||||
letta_client = "^0.1.23"
|
||||
openai = "^1.60.0"
|
||||
faker = "^36.1.0"
|
||||
colorama = "^0.4.6"
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
@@ -74,3 +75,16 @@ def test_letta_run_create_new_agent(swap_letta_config):
|
||||
# Count occurrences of assistant messages
|
||||
robot = full_output.count(ASSISTANT_MESSAGE_CLI_SYMBOL)
|
||||
assert robot == 1, f"It appears that there are multiple instances of assistant messages outputted."
|
||||
|
||||
|
||||
def test_letta_version_prints_only_version(swap_letta_config):
|
||||
# Start the letta version command
|
||||
output = pexpect.run("poetry run letta version", encoding="utf-8")
|
||||
|
||||
# Remove ANSI escape sequences and whitespace
|
||||
output = re.sub(r"\x1b\[[0-9;]*[mK]", "", output).strip()
|
||||
|
||||
from letta import __version__
|
||||
|
||||
# Get the full output and verify it contains only the version
|
||||
assert output == __version__, f"Expected only '{__version__}', but got '{repr(output)}'"
|
||||
|
||||
153
tests/test_google_embeddings.py
Normal file
153
tests/test_google_embeddings.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import httpx
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from letta.embeddings import GoogleEmbeddings # Adjust the import based on your module structure
|
||||
|
||||
load_dotenv()
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from letta_client import CreateBlock
|
||||
from letta_client import Letta as LettaSDKClient
|
||||
from letta_client import MessageCreate
|
||||
|
||||
SERVER_PORT = 8283
|
||||
|
||||
|
||||
def run_server():
|
||||
load_dotenv()
|
||||
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
print("Starting server...")
|
||||
start_server(debug=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client() -> LettaSDKClient:
|
||||
# Get URL from environment or start server
|
||||
server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:{SERVER_PORT}")
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
print("Starting server thread")
|
||||
thread = threading.Thread(target=run_server, daemon=True)
|
||||
thread.start()
|
||||
time.sleep(5)
|
||||
print("Running client tests with server:", server_url)
|
||||
client = LettaSDKClient(base_url=server_url, token=None)
|
||||
yield client
|
||||
|
||||
|
||||
def test_google_embeddings_response():
|
||||
api_key = os.environ.get("GEMINI_API_KEY")
|
||||
model = "text-embedding-004"
|
||||
base_url = "https://generativelanguage.googleapis.com"
|
||||
text = "Hello, world!"
|
||||
|
||||
embedding_model = GoogleEmbeddings(api_key, model, base_url)
|
||||
response = None
|
||||
|
||||
try:
|
||||
response = embedding_model.get_text_embedding(text)
|
||||
except httpx.HTTPStatusError as e:
|
||||
pytest.fail(f"Request failed with status code {e.response.status_code}")
|
||||
|
||||
assert response is not None, "No response received from API"
|
||||
assert isinstance(response, list), "Response is not a list of embeddings"
|
||||
|
||||
|
||||
def test_archival_insert_text_embedding_004(client: LettaSDKClient):
|
||||
"""
|
||||
Test that an agent with model 'gemini-2.0-flash-exp' and embedding 'text_embedding_004'
|
||||
correctly inserts a message into its archival memory.
|
||||
|
||||
The test works by:
|
||||
1. Creating an agent with the desired model and embedding.
|
||||
2. Sending a message prefixed with 'archive :' to instruct the agent to store the message in archival.
|
||||
3. Retrieving the archival memory via the agent messaging API.
|
||||
4. Verifying that the archival message is stored.
|
||||
"""
|
||||
# Create an agent with the specified model and embedding.
|
||||
agent = client.agents.create(
|
||||
name=f"archival_insert_text_embedding_004",
|
||||
memory_blocks=[
|
||||
CreateBlock(label="human", value="name: archival_test"),
|
||||
CreateBlock(label="persona", value="You are a helpful assistant that loves helping out the user"),
|
||||
],
|
||||
model="google_ai/gemini-2.0-flash-exp",
|
||||
embedding="google_ai/text-embedding-004",
|
||||
)
|
||||
|
||||
# Define the archival message.
|
||||
archival_message = "Archival insertion test message"
|
||||
|
||||
# Send a message instructing the agent to archive it.
|
||||
res = client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=[MessageCreate(role="user", content=f"Store this in your archive memory: {archival_message}")],
|
||||
)
|
||||
print(res.messages)
|
||||
|
||||
# Retrieve the archival messages through the agent messaging API.
|
||||
archived_messages = client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=[MessageCreate(role="user", content=f"retrieve from archival memory : {archival_message}")],
|
||||
)
|
||||
|
||||
print(archived_messages.messages)
|
||||
# Assert that the archival message is present.
|
||||
assert any(
|
||||
message.status == "success" for message in archived_messages.messages if message.message_type == "tool_return_message"
|
||||
), f"Archival message '{archival_message}' not found. Archived messages: {archived_messages}"
|
||||
|
||||
# Cleanup: Delete the agent.
|
||||
client.agents.delete(agent.id)
|
||||
|
||||
|
||||
def test_archival_insert_embedding_001(client: LettaSDKClient):
|
||||
"""
|
||||
Test that an agent with model 'gemini-2.0-flash-exp' and embedding 'embedding_001'
|
||||
correctly inserts a message into its archival memory.
|
||||
|
||||
The test works by:
|
||||
1. Creating an agent with the desired model and embedding.
|
||||
2. Sending a message prefixed with 'archive :' to instruct the agent to store the message in archival.
|
||||
3. Retrieving the archival memory via the agent messaging API.
|
||||
4. Verifying that the archival message is stored.
|
||||
"""
|
||||
# Create an agent with the specified model and embedding.
|
||||
agent = client.agents.create(
|
||||
name=f"archival_insert_embedding_001",
|
||||
memory_blocks=[
|
||||
CreateBlock(label="human", value="name: archival_test"),
|
||||
CreateBlock(label="persona", value="You are a helpful assistant that loves helping out the user"),
|
||||
],
|
||||
model="google_ai/gemini-2.0-flash-exp",
|
||||
embedding="google_ai/embedding-001",
|
||||
)
|
||||
|
||||
# Define the archival message.
|
||||
archival_message = "Archival insertion test message"
|
||||
|
||||
# Send a message instructing the agent to archive it.
|
||||
client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=[MessageCreate(role="user", content=f"archive : {archival_message}")],
|
||||
)
|
||||
|
||||
# Retrieve the archival messages through the agent messaging API.
|
||||
archived_messages = client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=[MessageCreate(role="user", content=f"retrieve from archival memory : {archival_message}")],
|
||||
)
|
||||
|
||||
# Assert that the archival message is present.
|
||||
assert any(
|
||||
message.status == "success" for message in archived_messages.messages if message.message_type == "tool_return_message"
|
||||
), f"Archival message '{archival_message}' not found. Archived messages: {archived_messages}"
|
||||
|
||||
# Cleanup: Delete the agent.
|
||||
client.agents.delete(agent.id)
|
||||
Reference in New Issue
Block a user