diff --git a/letta/server/server.py b/letta/server/server.py
index 2fc78496..106663f2 100644
--- a/letta/server/server.py
+++ b/letta/server/server.py
@@ -94,7 +94,7 @@ from letta.services.user_manager import UserManager
 from letta.settings import model_settings, settings, tool_settings
 from letta.sleeptime_agent import SleeptimeAgent
 from letta.tracing import trace_method
-from letta.utils import get_friendly_error_msg
+from letta.utils import get_friendly_error_msg, make_key
 
 config = LettaConfig.load()
 logger = get_logger(__name__)
@@ -346,6 +346,10 @@ class SyncServer(Server):
             logger.info(f"MCP tools connected: {', '.join([t.name for t in mcp_tools])}")
             logger.debug(f"MCP tools: {', '.join([str(t) for t in mcp_tools])}")
 
+        # TODO: Remove these in memory caches
+        self._llm_config_cache = {}
+        self._embedding_config_cache = {}
+
     def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent:
         """Updated method to load agents from persisted storage"""
         agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
@@ -696,6 +700,32 @@ class SyncServer(Server):
                 command = command[1:]  # strip the prefix
         return self._command(user_id=user_id, agent_id=agent_id, command=command)
 
+    def get_cached_llm_config(self, **kwargs):
+        """Return the LLM config for these arguments, resolving it only once.
+
+        On a miss ``kwargs`` are forwarded unchanged to
+        ``get_llm_config_from_handle``; ``make_key`` turns them into the cache key.
+        NOTE(review): the cached object itself is returned, so every caller with
+        the same arguments shares it -- confirm nothing downstream mutates it.
+        """
+        key = make_key(**kwargs)
+        if key not in self._llm_config_cache:
+            self._llm_config_cache[key] = self.get_llm_config_from_handle(**kwargs)
+        return self._llm_config_cache[key]
+
+    def get_cached_embedding_config(self, **kwargs):
+        """Return the embedding config for these arguments, resolving it only once.
+
+        On a miss ``kwargs`` are forwarded unchanged to
+        ``get_embedding_config_from_handle``; ``make_key`` turns them into the cache key.
+        NOTE(review): the cached object itself is returned, so every caller with
+        the same arguments shares it -- confirm nothing downstream mutates it.
+        """
+        key = make_key(**kwargs)
+        if key not in self._embedding_config_cache:
+            self._embedding_config_cache[key] = self.get_embedding_config_from_handle(**kwargs)
+        return self._embedding_config_cache[key]
+
     def create_agent(
         self,
         request: CreateAgent,
@@ -706,7 +736,7 @@ class SyncServer(Server):
         if request.llm_config is None:
             if request.model is None:
                 raise ValueError("Must specify either model or llm_config in request")
-            request.llm_config = self.get_llm_config_from_handle(
+            request.llm_config = self.get_cached_llm_config(
                 handle=request.model,
                 context_window_limit=request.context_window_limit,
                 max_tokens=request.max_tokens,
@@ -717,8 +747,9 @@ class SyncServer(Server):
         if request.embedding_config is None:
             if request.embedding is None:
                 raise ValueError("Must specify either embedding or embedding_config in request")
-            request.embedding_config = self.get_embedding_config_from_handle(
-                handle=request.embedding, embedding_chunk_size=request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE
+            request.embedding_config = self.get_cached_embedding_config(
+                handle=request.embedding,
+                embedding_chunk_size=request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
             )
 
         main_agent = self.agent_manager.create_agent(
diff --git a/letta/utils.py b/letta/utils.py
index 33927f83..57ff86d0 100644
--- a/letta/utils.py
+++ b/letta/utils.py
@@ -1070,3 +1070,12 @@ def log_telemetry(logger: Logger, event: str, **kwargs):
     timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S,%f UTC")  # More readable timestamp
     extra_data = " | ".join(f"{key}={value}" for key, value in kwargs.items() if value is not None)
     logger.info(f"[{timestamp}] EVENT: {event} | {extra_data}")
+
+
+def make_key(*args, **kwargs) -> str:
+    """Build a string cache key from call arguments.
+
+    Keyword arguments are sorted by name so the key does not depend on the
+    order in which they were passed.
+    """
+    return str((args, tuple(sorted(kwargs.items()))))
diff --git a/tests/integration_test_experimental.py b/tests/integration_test_experimental.py
index cd11676e..1ca45ec1 100644
--- a/tests/integration_test_experimental.py
+++ b/tests/integration_test_experimental.py
@@ -465,3 +465,44 @@ def test_anthropic_streaming(client: Letta):
     )
 
     print(list(response))
+
+
+import time
+
+
+def test_create_agents_telemetry(client: Letta):
+    """Time the creation of a batch of worker agents and print the timings."""
+    start_total = time.perf_counter()
+
+    # delete any existing worker agents
+    start_delete = time.perf_counter()
+    workers = client.agents.list(tags=["worker"])
+    for worker in workers:
+        client.agents.delete(agent_id=worker.id)
+    end_delete = time.perf_counter()
+    print(f"[telemetry] Deleted {len(workers)} existing worker agents in {end_delete - start_delete:.2f}s")
+
+    # create worker agents
+    num_workers = 100
+    agent_times = []
+    for idx in range(num_workers):
+        start = time.perf_counter()
+        client.agents.create(
+            name=f"worker_{idx}",
+            # tagged so the cleanup step above can find these agents on the next run
+            tags=["worker"],
+            include_base_tools=True,
+            model="anthropic/claude-3-5-sonnet-20241022",
+            embedding="letta/letta-free",
+        )
+        end = time.perf_counter()
+        duration = end - start
+        agent_times.append(duration)
+        print(f"[telemetry] Created worker_{idx} in {duration:.2f}s")
+
+    total_duration = time.perf_counter() - start_total
+    avg_duration = sum(agent_times) / len(agent_times)
+
+    print(f"[telemetry] Total time to create {num_workers} agents: {total_duration:.2f}s")
+    print(f"[telemetry] Average agent creation time: {avg_duration:.2f}s")
+    print(f"[telemetry] Fastest agent: {min(agent_times):.2f}s, Slowest agent: {max(agent_times):.2f}s")