From 698d99a66e2b4dc3efd08b8ff9474a1c5d3d07c8 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Wed, 18 Jun 2025 15:11:30 -0700 Subject: [PATCH] feat: Ungate file upload for simple MIME types even without Mistral API key (#2898) --- examples/docs/agent_advanced.py | 2 +- examples/docs/agent_basic.py | 2 +- examples/docs/example.py | 4 +- examples/docs/tools.py | 4 +- examples/mcp_example.py | 2 +- .../sleeptime/sleeptime_source_example.py | 2 +- letta/constants.py | 2 +- letta/schemas/embedding_config.py | 8 ++++ letta/server/rest_api/routers/v1/sources.py | 41 ++++++++++++------- .../embedder/openai_embedder.py | 2 +- .../file_processor/parser/mistral_parser.py | 4 +- paper_experiments/doc_qa_task/doc_qa.py | 2 +- tests/integration_test_sleeptime_agent.py | 8 ++-- tests/integration_test_voice_agent.py | 4 +- tests/sdk/agents_test.py | 2 +- tests/sdk/conftest.py | 2 +- tests/test_cli.py | 2 +- tests/test_multi_agent.py | 14 +++---- tests/test_sdk_client.py | 22 +++++----- tests/test_server.py | 20 ++++----- tests/test_sources.py | 22 +++++----- 21 files changed, 95 insertions(+), 76 deletions(-) diff --git a/examples/docs/agent_advanced.py b/examples/docs/agent_advanced.py index 94638327..a143076a 100644 --- a/examples/docs/agent_advanced.py +++ b/examples/docs/agent_advanced.py @@ -30,7 +30,7 @@ agent_state = client.agents.create( model="openai/gpt-4o-mini", context_window_limit=8000, # embedding model & endpoint configuration (cannot be changed) - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", # system instructions for the agent (defaults to `memgpt_chat`) system=gpt_system.get_system_text("memgpt_chat"), # whether to include base letta tools (default: True) diff --git a/examples/docs/agent_basic.py b/examples/docs/agent_basic.py index a6d26ed7..eb0bd952 100644 --- a/examples/docs/agent_basic.py +++ b/examples/docs/agent_basic.py @@ -19,7 +19,7 @@ agent_state = client.agents.create( ], # set automatic defaults for LLM/embedding config model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ) print(f"Created agent with name {agent_state.name} and unique ID {agent_state.id}") diff --git a/examples/docs/example.py b/examples/docs/example.py index 38f924ba..2f3530a2 100644 --- a/examples/docs/example.py +++ b/examples/docs/example.py @@ -21,7 +21,7 @@ agent = client.agents.create( ), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ) print(f"Created agent with name {agent.name}") @@ -120,7 +120,7 @@ for chunk in stream: agent_copy = client.agents.create( model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ) block = client.agents.blocks.retrieve(agent.id, block_label="human") agent_copy = client.agents.blocks.attach(agent_copy.id, block.id) diff --git a/examples/docs/tools.py b/examples/docs/tools.py index fdedef1b..728c8036 100644 --- a/examples/docs/tools.py +++ b/examples/docs/tools.py @@ -46,7 +46,7 @@ agent_state = client.agents.create( ], # set automatic defaults for LLM/embedding config model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", # create the agent with an additional tool tool_ids=[tool.id], tool_rules=[ @@ -89,7 +89,7 @@ agent_state = client.agents.create( ), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", include_base_tools=False, tool_ids=[tool.id, send_message_tool.id], ) diff --git a/examples/mcp_example.py b/examples/mcp_example.py index a12c3faf..9352a1a3 100644 --- a/examples/mcp_example.py +++ b/examples/mcp_example.py @@ -26,7 +26,7 @@ agent = client.agents.create( } ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", tool_ids=[mcp_tool.id] ) print(f"Created agent id {agent.id}") diff --git a/examples/sleeptime/sleeptime_source_example.py b/examples/sleeptime/sleeptime_source_example.py index efc9e8d3..c782060e 100644 --- a/examples/sleeptime/sleeptime_source_example.py +++ b/examples/sleeptime/sleeptime_source_example.py @@ -30,7 +30,7 @@ source_name = "employee_handbook" source = client.sources.create( name=source_name, description="Provides reference information for the employee handbook", - embedding="openai/text-embedding-ada-002" # must match agent + embedding="openai/text-embedding-3-small" # must match agent ) # attach the source to the agent client.agents.sources.attach( diff --git a/letta/constants.py b/letta/constants.py index aab388bf..3c5cb095 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -65,7 +65,7 @@ DEFAULT_EMBEDDING_CHUNK_SIZE = 300 # tokenizers EMBEDDING_TO_TOKENIZER_MAP = { - "text-embedding-ada-002": "cl100k_base", + "text-embedding-3-small": "cl100k_base", } EMBEDDING_TO_TOKENIZER_DEFAULT = "cl100k_base" diff --git a/letta/schemas/embedding_config.py b/letta/schemas/embedding_config.py index 25162d0b..5ea23a02 100644 --- a/letta/schemas/embedding_config.py +++ b/letta/schemas/embedding_config.py @@ -63,6 +63,14 @@ class EmbeddingConfig(BaseModel): embedding_dim=1536, embedding_chunk_size=300, ) + if model_name == "text-embedding-3-small" and provider == "openai": + return cls( + embedding_model="text-embedding-3-small", + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_dim=2000, + embedding_chunk_size=300, + ) elif model_name == "letta": return cls( embedding_endpoint="https://embeddings.memgpt.ai", diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index 89bfebc5..28179cbf 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -21,9 +21,14 @@ from letta.server.server import SyncServer from letta.services.file_processor.chunker.llama_index_chunker import LlamaIndexChunker from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder from letta.services.file_processor.file_processor import FileProcessor -from letta.services.file_processor.file_types import get_allowed_media_types, get_extension_to_mime_type_map, register_mime_types +from letta.services.file_processor.file_types import ( + get_allowed_media_types, + get_extension_to_mime_type_map, + is_simple_text_mime_type, + register_mime_types, +) from letta.services.file_processor.parser.mistral_parser import MistralFileParser -from letta.settings import model_settings, settings +from letta.settings import settings from letta.utils import safe_create_task, sanitize_filename logger = get_logger(__name__) @@ -228,20 +233,26 @@ async def upload_file_to_source( agent_states = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor) # NEW: Cloud based file processing - if settings.mistral_api_key and model_settings.openai_api_key: - logger.info("Running experimental cloud based file processing...") - safe_create_task( - load_file_to_source_cloud(server, agent_states, content, file, job, source_id, actor), - logger=logger, - label="file_processor.process", - ) - else: - # create background tasks - safe_create_task( - load_file_to_source_async(server, source_id=source.id, filename=file.filename, job_id=job.id, bytes=content, actor=actor), - logger=logger, - label="load_file_to_source_async", + # Determine file's MIME type + file_mime_type = mimetypes.guess_type(file.filename)[0] or "application/octet-stream" + + # Check if it's a simple text file + is_simple_file = is_simple_text_mime_type(file_mime_type) + + # For complex files, require Mistral API key + if not is_simple_file and not settings.mistral_api_key: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Mistral API key is required to process this file type {file_mime_type}. Please configure your Mistral API key to upload complex file formats.", ) + + # Use cloud processing for all files (simple files always, complex files with Mistral key) + logger.info("Running experimental cloud based file processing...") + safe_create_task( + load_file_to_source_cloud(server, agent_states, content, file, job, source_id, actor), + logger=logger, + label="file_processor.process", + ) safe_create_task(sleeptime_document_ingest_async(server, source_id, actor), logger=logger, label="sleeptime_document_ingest_async") return job diff --git a/letta/services/file_processor/embedder/openai_embedder.py b/letta/services/file_processor/embedder/openai_embedder.py index 5b1a2e72..cd2110ff 100644 --- a/letta/services/file_processor/embedder/openai_embedder.py +++ b/letta/services/file_processor/embedder/openai_embedder.py @@ -16,7 +16,7 @@ class OpenAIEmbedder: """OpenAI-based embedding generation""" def __init__(self, embedding_config: Optional[EmbeddingConfig] = None): - self.embedding_config = embedding_config or EmbeddingConfig.default_config(provider="openai") + self.embedding_config = embedding_config or EmbeddingConfig.default_config(model_name="text-embedding-3-small", provider="openai") # TODO: Unify to global OpenAI client self.client = openai.AsyncOpenAI(api_key=model_settings.openai_api_key) diff --git a/letta/services/file_processor/parser/mistral_parser.py b/letta/services/file_processor/parser/mistral_parser.py index 43d74872..6b2b13f6 100644 --- a/letta/services/file_processor/parser/mistral_parser.py +++ b/letta/services/file_processor/parser/mistral_parser.py @@ -20,11 +20,10 @@ class MistralFileParser(FileParser): async def extract_text(self, content: bytes, mime_type: str) -> OCRResponse: """Extract text using Mistral OCR or shortcut for plain text.""" try: - logger.info(f"Extracting text using Mistral OCR model: {self.model}") - # TODO: Kind of hacky...we try to exit early here? # TODO: Create our internal file parser representation we return instead of OCRResponse if is_simple_text_mime_type(mime_type): + logger.info(f"Extracting text directly (no Mistral): {self.model}") text = content.decode("utf-8", errors="replace") return OCRResponse( model=self.model, @@ -43,6 +42,7 @@ class MistralFileParser(FileParser): base64_encoded_content = base64.b64encode(content).decode("utf-8") document_url = f"data:{mime_type};base64,{base64_encoded_content}" + logger.info(f"Extracting text using Mistral OCR model: {self.model}") async with Mistral(api_key=settings.mistral_api_key) as mistral: ocr_response = await mistral.ocr.process_async( model="mistral-ocr-latest", document={"type": "document_url", "document_url": document_url}, include_image_base64=False diff --git a/paper_experiments/doc_qa_task/doc_qa.py b/paper_experiments/doc_qa_task/doc_qa.py index e07060d1..af2cf86a 100644 --- a/paper_experiments/doc_qa_task/doc_qa.py +++ b/paper_experiments/doc_qa_task/doc_qa.py @@ -7,7 +7,7 @@ asked to use the provided documents to answer the question. Similar to Liu et al we evaluate reader accuracy as the number of retrieved documents K increases. In our evaluation setup, both the fixed-context baselines and Letta use the same retriever, which selects the top K documents according using Faiss efficient similarity search (Johnson et al., 2019) (which corresponds to -approximate nearest neighbor search) on OpenAI's text-embedding-ada-002 embeddings. In +approximate nearest neighbor search) on OpenAI's text-embedding-3-small embeddings. In Letta, the entire document set is loaded into archival storage, and the retriever naturally emerges via the archival storage search functionality (which performs embedding-based similarity search). In the fixed-context baselines, the top-K documents are fetched using the retriever independently diff --git a/tests/integration_test_sleeptime_agent.py b/tests/integration_test_sleeptime_agent.py index 8e87643c..db0a50b9 100644 --- a/tests/integration_test_sleeptime_agent.py +++ b/tests/integration_test_sleeptime_agent.py @@ -76,7 +76,7 @@ async def test_sleeptime_group_chat(server, actor): ], # model="openai/gpt-4o-mini", model="anthropic/claude-3-5-sonnet-20240620", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", enable_sleeptime=True, ), actor=actor, @@ -190,7 +190,7 @@ async def test_sleeptime_group_chat_v2(server, actor): ], # model="openai/gpt-4o-mini", model="anthropic/claude-3-5-sonnet-20240620", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", enable_sleeptime=True, ), actor=actor, @@ -310,7 +310,7 @@ async def test_sleeptime_removes_redundant_information(server, actor): ), ], model="anthropic/claude-3-5-sonnet-20240620", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", enable_sleeptime=True, ), actor=actor, @@ -389,7 +389,7 @@ async def test_sleeptime_edit(server, actor): ), ], model="anthropic/claude-3-5-sonnet-20240620", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", enable_sleeptime=True, ), actor=actor, diff --git a/tests/integration_test_voice_agent.py b/tests/integration_test_voice_agent.py index d25ccabe..1c61dcec 100644 --- a/tests/integration_test_voice_agent.py +++ b/tests/integration_test_voice_agent.py @@ -190,7 +190,7 @@ def voice_agent(server, actor, roll_dice_tool): ), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", enable_sleeptime=True, tool_ids=[roll_dice_tool.id, run_code_tool.id], ), @@ -279,7 +279,7 @@ async def test_model_compatibility(model, message, server, server_url, actor, ro ), ], model=model, - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", enable_sleeptime=True, tool_ids=[roll_dice_tool.id, run_code_tool.id], ), diff --git a/tests/sdk/agents_test.py b/tests/sdk/agents_test.py index 9da54d46..74830702 100644 --- a/tests/sdk/agents_test.py +++ b/tests/sdk/agents_test.py @@ -1,7 +1,7 @@ from conftest import create_test_module AGENTS_CREATE_PARAMS = [ - ("caren_agent", {"name": "caren", "model": "openai/gpt-4o-mini", "embedding": "openai/text-embedding-ada-002"}, {}, None), + ("caren_agent", {"name": "caren", "model": "openai/gpt-4o-mini", "embedding": "openai/text-embedding-3-small"}, {}, None), ] AGENTS_MODIFY_PARAMS = [ diff --git a/tests/sdk/conftest.py b/tests/sdk/conftest.py index 36317fb4..b05eb3d7 100644 --- a/tests/sdk/conftest.py +++ b/tests/sdk/conftest.py @@ -87,7 +87,7 @@ def create_test_module( agent = client.agents.create( name="caren_agent", model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ) # Add finalizer to ensure cleanup happens in the right order diff --git a/tests/test_cli.py b/tests/test_cli.py index c0b8525b..fd62e396 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -57,7 +57,7 @@ def test_letta_run_create_new_agent(swap_letta_config): # Optional: Embedding model selection try: child.expect("Select embedding model:", timeout=20) - child.sendline("text-embedding-ada-002") + child.sendline("text-embedding-3-small") except (pexpect.TIMEOUT, pexpect.EOF): print("[WARNING] Embedding model selection step was skipped.") diff --git a/tests/test_multi_agent.py b/tests/test_multi_agent.py index e118f6a3..2dda36a9 100644 --- a/tests/test_multi_agent.py +++ b/tests/test_multi_agent.py @@ -66,7 +66,7 @@ def participant_agents(server, actor): ), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ), actor=actor, ) @@ -80,7 +80,7 @@ def participant_agents(server, actor): ), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ), actor=actor, ) @@ -94,7 +94,7 @@ def participant_agents(server, actor): ), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ), actor=actor, ) @@ -108,7 +108,7 @@ def participant_agents(server, actor): ), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ), actor=actor, ) @@ -137,7 +137,7 @@ def manager_agent(server, actor): ), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ), actor=actor, ) @@ -350,7 +350,7 @@ async def test_supervisor(server, actor, participant_agents): ), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ), actor=actor, ) @@ -420,7 +420,7 @@ async def test_dynamic_group_chat(server, actor, manager_agent, participant_agen request=CreateAgent( name="shaggy", model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ), actor=actor, ) diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index f7d4a89b..0e81f40b 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -48,7 +48,7 @@ def agent(client: LettaSDKClient): ), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ) yield agent_state @@ -74,7 +74,7 @@ def test_shared_blocks(client: LettaSDKClient): ], block_ids=[block.id], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ) agent_state2 = client.agents.create( name="agent2", @@ -86,7 +86,7 @@ def test_shared_blocks(client: LettaSDKClient): ], block_ids=[block.id], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ) # update memory @@ -132,7 +132,7 @@ def test_read_only_block(client: LettaSDKClient): ), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ) # make sure agent cannot update read-only block @@ -175,7 +175,7 @@ def test_add_and_manage_tags_for_agent(client: LettaSDKClient): ), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ) assert len(agent.tags) == 0 @@ -227,7 +227,7 @@ def test_agent_tags(client: LettaSDKClient): ), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", tags=["test", "agent1", "production"], ) @@ -239,7 +239,7 @@ def test_agent_tags(client: LettaSDKClient): ), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", tags=["test", "agent2", "development"], ) @@ -251,7 +251,7 @@ def test_agent_tags(client: LettaSDKClient): ), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", tags=["test", "agent3", "production"], ) @@ -556,7 +556,7 @@ def test_agent_creation(client: LettaSDKClient): name=f"test_agent_{str(uuid.uuid4())}", memory_blocks=[sleeptime_persona_block, mindy_block], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", tool_ids=[tool1.id, tool2.id], include_base_tools=False, tags=["test"], @@ -595,7 +595,7 @@ def test_many_blocks(client: LettaSDKClient): ), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", include_base_tools=False, tags=["test"], ) @@ -612,7 +612,7 @@ def test_many_blocks(client: LettaSDKClient): ), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", include_base_tools=False, tags=["test"], ) diff --git a/tests/test_server.py b/tests/test_server.py index 6a13f672..23756386 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -330,7 +330,7 @@ def agent_id(server, user_id, base_tools): tool_ids=[t.id for t in base_tools], memory_blocks=[], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ), actor=actor, ) @@ -350,7 +350,7 @@ def other_agent_id(server, user_id, base_tools): tool_ids=[t.id for t in base_tools], memory_blocks=[], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ), actor=actor, ) @@ -523,7 +523,7 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User): name="nonexistent_tools_agent", memory_blocks=[], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ), actor=user, ) @@ -577,7 +577,7 @@ async def test_read_local_llm_configs(server: SyncServer, user: User, event_loop request=CreateAgent( model="caren/my-custom-model", context_window_limit=context_window_override, - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ), actor=user, ) @@ -914,7 +914,7 @@ async def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tool CreateBlock(label="persona", value="My name is Alice."), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ), actor=actor, ) @@ -965,7 +965,7 @@ def test_add_nonexisting_tool(server: SyncServer, user_id: str, base_tools): CreateBlock(label="persona", value="My name is Alice."), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", include_base_tools=True, ), actor=actor, @@ -982,7 +982,7 @@ def test_default_tool_rules(server: SyncServer, user_id: str, base_tools, base_m tool_ids=[t.id for t in base_tools + base_memory_tools], memory_blocks=[], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", include_base_tools=False, ), actor=actor, @@ -1005,7 +1005,7 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to CreateBlock(label="persona", value="My name is Alice."), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", include_base_tools=False, ), actor=actor, @@ -1035,7 +1035,7 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to "embedding_config": { "embedding_endpoint_type": "openai", "embedding_endpoint": "https://api.openai.com/v1", - "embedding_model": "text-embedding-ada-002", + "embedding_model": "text-embedding-3-small", "embedding_dim": 1536, "embedding_chunk_size": 300, "azure_endpoint": None, @@ -1086,7 +1086,7 @@ async def test_messages_with_provider_override(server: SyncServer, user_id: str, memory_blocks=[], model="caren-anthropic/claude-3-5-sonnet-20240620", context_window_limit=100000, - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ), actor=actor, ) diff --git a/tests/test_sources.py b/tests/test_sources.py index 15580189..266fdd39 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -68,7 +68,7 @@ def agent_state(client: LettaSDKClient): ), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", tool_ids=[open_file_tool.id, close_file_tool.id, search_files_tool.id, grep_tool.id], ) yield agent_state @@ -85,7 +85,7 @@ def test_auto_attach_detach_files_tools(client: LettaSDKClient): CreateBlock(label="human", value="username: sarah"), ], model="openai/gpt-4o-mini", - embedding="openai/text-embedding-ada-002", + embedding="openai/text-embedding-3-small", ) # Helper function to get file tools from agent @@ -106,14 +106,14 @@ def test_auto_attach_detach_files_tools(client: LettaSDKClient): assert_no_file_tools(agent) # Create and attach first source - source_1 = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002") + source_1 = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") assert len(client.sources.list()) == 1 agent = client.agents.sources.attach(source_id=source_1.id, agent_id=agent.id) assert_file_tools_present(agent, set(FILES_TOOLS)) # Create and attach second source - source_2 = client.sources.create(name="another_test_source", embedding="openai/text-embedding-ada-002") + source_2 = client.sources.create(name="another_test_source", embedding="openai/text-embedding-3-small") assert len(client.sources.list()) == 2 agent = client.agents.sources.attach(source_id=source_2.id, agent_id=agent.id) @@ -152,7 +152,7 @@ def test_file_upload_creates_source_blocks_correctly( expected_label_regex: str, ): # Create a new source - source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002") + source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") assert len(client.sources.list()) == 1 # Attach @@ -196,7 +196,7 @@ def test_file_upload_creates_source_blocks_correctly( def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState): # Create a new source - source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002") + source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") assert len(client.sources.list()) == 1 # Load files into the source @@ -240,7 +240,7 @@ def test_attach_existing_files_creates_source_blocks_correctly(client: LettaSDKC def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, agent_state: AgentState): # Create a new source - source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002") + source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") assert len(client.sources.list()) == 1 # Attach @@ -279,7 +279,7 @@ def test_delete_source_removes_source_blocks_correctly(client: LettaSDKClient, a def test_agent_uses_open_close_file_correctly(client: LettaSDKClient, agent_state: AgentState): # Create a new source - source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002") + source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") sources_list = client.sources.list() assert len(sources_list) == 1 @@ -388,7 +388,7 @@ def test_agent_uses_open_close_file_correctly(client: LettaSDKClient, agent_stat def test_agent_uses_search_files_correctly(client: LettaSDKClient, agent_state: AgentState): # Create a new source - source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002") + source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") sources_list = client.sources.list() assert len(sources_list) == 1 @@ -440,7 +440,7 @@ def test_agent_uses_search_files_correctly(client: LettaSDKClient, agent_state: def test_agent_uses_grep_correctly(client: LettaSDKClient, agent_state: AgentState): # Create a new source - source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002") + source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") sources_list = client.sources.list() assert len(sources_list) == 1 @@ -490,7 +490,7 @@ def test_agent_uses_grep_correctly(client: LettaSDKClient, agent_state: AgentSta def test_view_ranges_have_metadata(client: LettaSDKClient, agent_state: AgentState): # Create a new source - source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002") + source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") sources_list = client.sources.list() assert len(sources_list) == 1