fix: add LLMCallType enum and ensure call_type is set on all provider traces (#9258)

Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
Sarah Wooders
2026-02-03 17:03:23 -08:00
committed by Caren Thomas
parent 96c4b7175e
commit eaf64fb510
17 changed files with 72 additions and 37 deletions

View File

@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
from typing import AsyncGenerator
from letta.llm_api.llm_client_base import LLMClientBase
from letta.schemas.enums import LLMCallType
from letta.schemas.letta_message import LettaMessage
from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.llm_config import LLMConfig
@@ -24,6 +25,7 @@ class LettaLLMAdapter(ABC):
self,
llm_client: LLMClientBase,
llm_config: LLMConfig,
call_type: LLMCallType,
agent_id: str | None = None,
agent_tags: list[str] | None = None,
run_id: str | None = None,
@@ -32,6 +34,7 @@ class LettaLLMAdapter(ABC):
) -> None:
self.llm_client: LLMClientBase = llm_client
self.llm_config: LLMConfig = llm_config
self.call_type: LLMCallType = call_type
self.agent_id: str | None = agent_id
self.agent_tags: list[str] | None = agent_tags
self.run_id: str | None = run_id

View File

@@ -127,7 +127,7 @@ class LettaLLMRequestAdapter(LettaLLMAdapter):
agent_id=self.agent_id,
agent_tags=self.agent_tags,
run_id=self.run_id,
call_type="agent_step",
call_type=self.call_type,
org_id=self.org_id,
user_id=self.user_id,
llm_config=self.llm_config.model_dump() if self.llm_config else None,

View File

@@ -6,7 +6,7 @@ from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInt
from letta.interfaces.openai_streaming_interface import OpenAIStreamingInterface
from letta.llm_api.llm_client_base import LLMClientBase
from letta.otel.tracing import log_attributes, safe_json_dumps, trace_method
from letta.schemas.enums import ProviderType
from letta.schemas.enums import LLMCallType, ProviderType
from letta.schemas.letta_message import LettaMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.provider_trace import ProviderTrace
@@ -30,13 +30,14 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
self,
llm_client: LLMClientBase,
llm_config: LLMConfig,
call_type: LLMCallType,
agent_id: str | None = None,
agent_tags: list[str] | None = None,
run_id: str | None = None,
org_id: str | None = None,
user_id: str | None = None,
) -> None:
super().__init__(llm_client, llm_config, agent_id=agent_id, agent_tags=agent_tags, run_id=run_id, org_id=org_id, user_id=user_id)
super().__init__(llm_client, llm_config, call_type=call_type, agent_id=agent_id, agent_tags=agent_tags, run_id=run_id, org_id=org_id, user_id=user_id)
self.interface: OpenAIStreamingInterface | AnthropicStreamingInterface | None = None
async def invoke_llm(
@@ -205,7 +206,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
agent_id=self.agent_id,
agent_tags=self.agent_tags,
run_id=self.run_id,
call_type="agent_step",
call_type=self.call_type,
org_id=self.org_id,
user_id=self.user_id,
llm_config=self.llm_config.model_dump() if self.llm_config else None,

View File

@@ -2,6 +2,7 @@ from typing import AsyncGenerator
from letta.adapters.letta_llm_request_adapter import LettaLLMRequestAdapter
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
from letta.schemas.enums import LLMCallType
from letta.schemas.letta_message import LettaMessage
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, TextContent
from letta.schemas.usage import normalize_cache_tokens, normalize_reasoning_tokens
@@ -45,7 +46,7 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
agent_id=self.agent_id,
agent_tags=self.agent_tags,
run_id=self.run_id,
call_type="agent_step",
call_type=LLMCallType.agent_step,
org_id=self.org_id,
user_id=self.user_id,
llm_config=self.llm_config.model_dump() if self.llm_config else None,

View File

@@ -254,7 +254,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
agent_id=self.agent_id,
agent_tags=self.agent_tags,
run_id=self.run_id,
call_type="agent_step",
call_type=self.call_type,
org_id=self.org_id,
user_id=self.user_id,
llm_config=self.llm_config.model_dump() if self.llm_config else None,

View File

@@ -8,7 +8,7 @@ from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.prompts.gpt_system import get_system_text
from letta.schemas.block import Block, BlockUpdate
from letta.schemas.enums import MessageRole
from letta.schemas.enums import LLMCallType, MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import Message, MessageCreate
from letta.schemas.user import User
@@ -92,7 +92,7 @@ class EphemeralSummaryAgent(BaseAgent):
telemetry_manager=TelemetryManager(),
agent_id=self.agent_id,
agent_tags=agent_state.tags,
call_type="summarization",
call_type=LLMCallType.summarization,
)
response_data = await llm_client.request_async_with_telemetry(request_data, agent_state.llm_config)
response = await llm_client.convert_response_to_chat_completion(response_data, messages, agent_state.llm_config)

View File

@@ -35,7 +35,7 @@ from letta.otel.context import get_ctx_attributes
from letta.otel.metric_registry import MetricRegistry
from letta.otel.tracing import log_event, trace_method, tracer
from letta.schemas.agent import AgentState, UpdateAgent
from letta.schemas.enums import JobStatus, ProviderType, StepStatus, ToolType
from letta.schemas.enums import JobStatus, LLMCallType, ProviderType, StepStatus, ToolType
from letta.schemas.letta_message import MessageType
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_response import LettaResponse
@@ -420,7 +420,7 @@ class LettaAgent(BaseAgent):
agent_id=self.agent_id,
agent_tags=agent_state.tags,
run_id=self.current_run_id,
call_type="agent_step",
call_type=LLMCallType.agent_step,
org_id=self.actor.organization_id,
user_id=self.actor.id,
llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
@@ -774,7 +774,7 @@ class LettaAgent(BaseAgent):
agent_id=self.agent_id,
agent_tags=agent_state.tags,
run_id=self.current_run_id,
call_type="agent_step",
call_type=LLMCallType.agent_step,
org_id=self.actor.organization_id,
user_id=self.actor.id,
llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
@@ -1252,7 +1252,7 @@ class LettaAgent(BaseAgent):
agent_id=self.agent_id,
agent_tags=agent_state.tags,
run_id=self.current_run_id,
call_type="agent_step",
call_type=LLMCallType.agent_step,
org_id=self.actor.organization_id,
user_id=self.actor.id,
llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
@@ -1486,7 +1486,7 @@ class LettaAgent(BaseAgent):
agent_tags=agent_state.tags,
run_id=self.current_run_id,
step_id=step_metrics.id,
call_type="agent_step",
call_type=LLMCallType.agent_step,
)
response = await llm_client.request_async_with_telemetry(request_data, agent_state.llm_config)
@@ -1559,7 +1559,7 @@ class LettaAgent(BaseAgent):
agent_tags=agent_state.tags,
run_id=self.current_run_id,
step_id=step_id,
call_type="agent_step",
call_type=LLMCallType.agent_step,
)
# Attempt LLM request with telemetry wrapper

View File

@@ -31,7 +31,7 @@ from letta.log import get_logger
from letta.otel.tracing import log_event, trace_method, tracer
from letta.prompts.prompt_generator import PromptGenerator
from letta.schemas.agent import AgentState, UpdateAgent
from letta.schemas.enums import AgentType, MessageStreamStatus, RunStatus, StepStatus
from letta.schemas.enums import AgentType, LLMCallType, MessageStreamStatus, RunStatus, StepStatus
from letta.schemas.letta_message import LettaMessage, MessageType
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_request import ClientToolSchema
@@ -158,6 +158,8 @@ class LettaAgentV2(BaseAgentV2):
llm_adapter=LettaLLMRequestAdapter(
llm_client=self.llm_client,
llm_config=self.agent_state.llm_config,
call_type=LLMCallType.agent_step,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
org_id=self.actor.organization_id,
user_id=self.actor.id,
@@ -216,6 +218,7 @@ class LettaAgentV2(BaseAgentV2):
llm_adapter=LettaLLMRequestAdapter(
llm_client=self.llm_client,
llm_config=self.agent_state.llm_config,
call_type=LLMCallType.agent_step,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
@@ -305,6 +308,7 @@ class LettaAgentV2(BaseAgentV2):
llm_adapter = LettaLLMStreamAdapter(
llm_client=self.llm_client,
llm_config=self.agent_state.llm_config,
call_type=LLMCallType.agent_step,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
@@ -315,6 +319,7 @@ class LettaAgentV2(BaseAgentV2):
llm_adapter = LettaLLMRequestAdapter(
llm_client=self.llm_client,
llm_config=self.agent_state.llm_config,
call_type=LLMCallType.agent_step,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,

View File

@@ -28,7 +28,7 @@ from letta.helpers.tool_execution_helper import enable_strict_mode
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole
from letta.schemas.enums import LLMCallType, MessageRole
from letta.schemas.letta_message import (
ApprovalReturn,
CompactionStats,
@@ -209,6 +209,7 @@ class LettaAgentV3(LettaAgentV2):
llm_adapter=SimpleLLMRequestAdapter(
llm_client=self.llm_client,
llm_config=self.agent_state.llm_config,
call_type=LLMCallType.agent_step,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
@@ -356,6 +357,7 @@ class LettaAgentV3(LettaAgentV2):
llm_adapter = SimpleLLMStreamAdapter(
llm_client=self.llm_client,
llm_config=self.agent_state.llm_config,
call_type=LLMCallType.agent_step,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
@@ -366,6 +368,7 @@ class LettaAgentV3(LettaAgentV2):
llm_adapter = SimpleLLMRequestAdapter(
llm_client=self.llm_client,
llm_config=self.agent_state.llm_config,
call_type=LLMCallType.agent_step,
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,

View File

@@ -23,7 +23,7 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.orm.user import User
from letta.otel.tracing import log_event, trace_method
from letta.schemas.enums import ProviderCategory
from letta.schemas.enums import LLMCallType, ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
@@ -245,7 +245,7 @@ def create(
request_json=prepare_openai_payload(data),
response_json=response.model_json_schema(),
step_id=step_id,
call_type="agent_step",
call_type=LLMCallType.agent_step,
),
)

View File

@@ -10,7 +10,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.errors import ErrorCode, LLMConnectionError, LLMError
from letta.otel.tracing import log_event, trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import AgentType, ProviderCategory
from letta.schemas.enums import AgentType, LLMCallType, ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
@@ -229,6 +229,7 @@ class LLMClientBase:
request_json=request_data,
response_json=response_data,
step_id=step_id,
call_type=LLMCallType.agent_step,
),
)
log_event(name="llm_response_received", attributes=response_data)
@@ -262,6 +263,7 @@ class LLMClientBase:
request_json=request_data,
response_json=response_data,
step_id=step_id,
call_type=LLMCallType.agent_step,
),
)

View File

@@ -96,6 +96,14 @@ class ProviderCategory(str, Enum):
byok = "byok"
class LLMCallType(str, Enum):
    """Type of LLM call for telemetry tracking.

    The selected member is passed as the ``call_type`` argument when provider
    traces are built or when a client's telemetry context is configured, so
    that call sites share one closed set of values rather than free-form
    string literals.
    """

    # Call issued as part of an agent step (the value the agent classes and
    # LLM adapters pass for their requests).
    agent_step = "agent_step"
    # Call issued to summarize messages (used by the summarization helpers).
    summarization = "summarization"
    # Call issued to generate a tool from a prompt.
    tool_generation = "tool_generation"
class MessageRole(str, Enum):
assistant = "assistant"
user = "user"

View File

@@ -26,7 +26,7 @@ from letta.log import get_logger
from letta.orm.errors import UniqueConstraintViolationError
from letta.orm.mcp_oauth import OAuthSessionStatus
from letta.prompts.gpt_system import get_system_text
from letta.schemas.enums import AgentType, MessageRole, ToolType
from letta.schemas.enums import AgentType, LLMCallType, MessageRole, ToolType
from letta.schemas.letta_message import ToolReturnMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.mcp import UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer
@@ -956,7 +956,7 @@ async def generate_tool_from_prompt(
llm_client.set_telemetry_context(
telemetry_manager=TelemetryManager(),
call_type="tool_generation",
call_type=LLMCallType.tool_generation,
)
response_data = await llm_client.request_async_with_telemetry(request_data, llm_config)
response = await llm_client.convert_response_to_chat_completion(response_data, input_messages, llm_config)

View File

@@ -16,7 +16,7 @@ from letta.llm_api.llm_client import LLMClient
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.prompts import gpt_summarize
from letta.schemas.enums import AgentType, MessageRole, ProviderType
from letta.schemas.enums import AgentType, LLMCallType, MessageRole, ProviderType
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message, MessageCreate
@@ -482,7 +482,7 @@ async def simple_summary(
agent_tags=agent_tags,
run_id=run_id,
step_id=step_id,
call_type="summarization",
call_type=LLMCallType.summarization,
org_id=actor.organization_id if actor else None,
user_id=actor.id if actor else None,
compaction_settings=compaction_settings,

View File

@@ -5,6 +5,7 @@ import pytest
from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter
from letta.errors import ContextWindowExceededError, LLMConnectionError, LLMServerError
from letta.llm_api.anthropic_client import AnthropicClient
from letta.schemas.enums import LLMCallType
from letta.schemas.llm_config import LLMConfig
@@ -42,7 +43,7 @@ async def test_letta_llm_stream_adapter_converts_anthropic_streaming_api_status_
llm_client = AnthropicClient()
llm_config = LLMConfig(model="claude-sonnet-4-5-20250929", model_endpoint_type="anthropic", context_window=200000)
adapter = LettaLLMStreamAdapter(llm_client=llm_client, llm_config=llm_config)
adapter = LettaLLMStreamAdapter(llm_client=llm_client, llm_config=llm_config, call_type=LLMCallType.agent_step)
gen = adapter.invoke_llm(request_data={}, messages=[], tools=[], use_assistant_message=True)
with pytest.raises(LLMServerError):
@@ -83,7 +84,7 @@ async def test_letta_llm_stream_adapter_converts_anthropic_413_request_too_large
llm_client = AnthropicClient()
llm_config = LLMConfig(model="claude-sonnet-4-5-20250929", model_endpoint_type="anthropic", context_window=200000)
adapter = LettaLLMStreamAdapter(llm_client=llm_client, llm_config=llm_config)
adapter = LettaLLMStreamAdapter(llm_client=llm_client, llm_config=llm_config, call_type=LLMCallType.agent_step)
gen = adapter.invoke_llm(request_data={}, messages=[], tools=[], use_assistant_message=True)
with pytest.raises(ContextWindowExceededError):
@@ -117,7 +118,7 @@ async def test_letta_llm_stream_adapter_converts_httpx_read_error(monkeypatch):
llm_client = AnthropicClient()
llm_config = LLMConfig(model="claude-sonnet-4-5-20250929", model_endpoint_type="anthropic", context_window=200000)
adapter = LettaLLMStreamAdapter(llm_client=llm_client, llm_config=llm_config)
adapter = LettaLLMStreamAdapter(llm_client=llm_client, llm_config=llm_config, call_type=LLMCallType.agent_step)
gen = adapter.invoke_llm(request_data={}, messages=[], tools=[], use_assistant_message=True)
with pytest.raises(LLMConnectionError):
@@ -151,7 +152,7 @@ async def test_letta_llm_stream_adapter_converts_httpx_write_error(monkeypatch):
llm_client = AnthropicClient()
llm_config = LLMConfig(model="claude-sonnet-4-5-20250929", model_endpoint_type="anthropic", context_window=200000)
adapter = LettaLLMStreamAdapter(llm_client=llm_client, llm_config=llm_config)
adapter = LettaLLMStreamAdapter(llm_client=llm_client, llm_config=llm_config, call_type=LLMCallType.agent_step)
gen = adapter.invoke_llm(request_data={}, messages=[], tools=[], use_assistant_message=True)
with pytest.raises(LLMConnectionError):

View File

@@ -198,6 +198,7 @@ class TestAdapterTelemetryAttributes:
"""Verify base LettaLLMAdapter has telemetry attributes."""
from letta.adapters.letta_llm_adapter import LettaLLMAdapter
from letta.llm_api.llm_client import LLMClient
from letta.schemas.enums import LLMCallType
mock_client = LLMClient.create(provider_type="openai", put_inner_thoughts_first=True)
@@ -212,6 +213,7 @@ class TestAdapterTelemetryAttributes:
adapter = TestAdapter(
llm_client=mock_client,
llm_config=mock_llm_config,
call_type=LLMCallType.agent_step,
agent_id=agent_id,
agent_tags=agent_tags,
run_id=run_id,
@@ -220,11 +222,13 @@ class TestAdapterTelemetryAttributes:
assert adapter.agent_id == agent_id
assert adapter.agent_tags == agent_tags
assert adapter.run_id == run_id
assert adapter.call_type == LLMCallType.agent_step
def test_request_adapter_inherits_telemetry_attributes(self, mock_llm_config):
"""Verify LettaLLMRequestAdapter inherits telemetry attributes."""
from letta.adapters.letta_llm_request_adapter import LettaLLMRequestAdapter
from letta.llm_api.llm_client import LLMClient
from letta.schemas.enums import LLMCallType
mock_client = LLMClient.create(provider_type="openai", put_inner_thoughts_first=True)
@@ -235,6 +239,7 @@ class TestAdapterTelemetryAttributes:
adapter = LettaLLMRequestAdapter(
llm_client=mock_client,
llm_config=mock_llm_config,
call_type=LLMCallType.agent_step,
agent_id=agent_id,
agent_tags=agent_tags,
run_id=run_id,
@@ -248,6 +253,7 @@ class TestAdapterTelemetryAttributes:
"""Verify LettaLLMStreamAdapter inherits telemetry attributes."""
from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter
from letta.llm_api.llm_client import LLMClient
from letta.schemas.enums import LLMCallType
mock_client = LLMClient.create(provider_type="openai", put_inner_thoughts_first=True)
@@ -258,6 +264,7 @@ class TestAdapterTelemetryAttributes:
adapter = LettaLLMStreamAdapter(
llm_client=mock_client,
llm_config=mock_llm_config,
call_type=LLMCallType.agent_step,
agent_id=agent_id,
agent_tags=agent_tags,
run_id=run_id,
@@ -272,13 +279,14 @@ class TestAdapterTelemetryAttributes:
from letta.adapters.letta_llm_request_adapter import LettaLLMRequestAdapter
from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter
from letta.llm_api.llm_client import LLMClient
from letta.schemas.enums import LLMCallType
mock_client = LLMClient.create(provider_type="openai", put_inner_thoughts_first=True)
request_adapter = LettaLLMRequestAdapter(llm_client=mock_client, llm_config=mock_llm_config)
stream_adapter = LettaLLMStreamAdapter(llm_client=mock_client, llm_config=mock_llm_config)
request_adapter = LettaLLMRequestAdapter(llm_client=mock_client, llm_config=mock_llm_config, call_type=LLMCallType.agent_step)
stream_adapter = LettaLLMStreamAdapter(llm_client=mock_client, llm_config=mock_llm_config, call_type=LLMCallType.agent_step)
for attr in ["agent_id", "agent_tags", "run_id"]:
for attr in ["agent_id", "agent_tags", "run_id", "call_type"]:
assert hasattr(request_adapter, attr), f"LettaLLMRequestAdapter missing {attr}"
assert hasattr(stream_adapter, attr), f"LettaLLMStreamAdapter missing {attr}"

View File

@@ -24,7 +24,7 @@ from letta.errors import LLMAuthenticationError
from letta.llm_api.anthropic_client import AnthropicClient
from letta.llm_api.google_ai_client import GoogleAIClient
from letta.llm_api.openai_client import OpenAIClient
from letta.schemas.enums import AgentType, MessageRole
from letta.schemas.enums import AgentType, LLMCallType, MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
@@ -156,6 +156,7 @@ async def test_openai_usage_via_adapter():
adapter = SimpleLLMRequestAdapter(
llm_client=client,
llm_config=llm_config,
call_type=LLMCallType.agent_step,
)
messages = _build_simple_messages("Say hello in exactly 5 words.")
@@ -209,6 +210,7 @@ async def test_anthropic_usage_via_adapter():
adapter = SimpleLLMRequestAdapter(
llm_client=client,
llm_config=llm_config,
call_type=LLMCallType.agent_step,
)
# Anthropic requires a system message first
@@ -262,6 +264,7 @@ async def test_gemini_usage_via_adapter():
adapter = SimpleLLMRequestAdapter(
llm_client=client,
llm_config=llm_config,
call_type=LLMCallType.agent_step,
)
messages = _build_simple_messages("Say hello in exactly 5 words.")
@@ -307,7 +310,7 @@ async def test_openai_prefix_caching_via_adapter():
llm_config = LLMConfig.default_config("gpt-4o-mini")
# First request - should populate the cache
adapter1 = SimpleLLMRequestAdapter(llm_client=client, llm_config=llm_config)
adapter1 = SimpleLLMRequestAdapter(llm_client=client, llm_config=llm_config, call_type=LLMCallType.agent_step)
messages1 = [
Message(role=MessageRole.system, content=[TextContent(text=LARGE_SYSTEM_PROMPT)]),
Message(role=MessageRole.user, content=[TextContent(text="What is 2+2?")]),
@@ -323,7 +326,7 @@ async def test_openai_prefix_caching_via_adapter():
print(f"Request 1 - prompt={adapter1.usage.prompt_tokens}, cached={adapter1.usage.cached_input_tokens}")
# Second request - same system prompt, should hit cache
adapter2 = SimpleLLMRequestAdapter(llm_client=client, llm_config=llm_config)
adapter2 = SimpleLLMRequestAdapter(llm_client=client, llm_config=llm_config, call_type=LLMCallType.agent_step)
messages2 = [
Message(role=MessageRole.system, content=[TextContent(text=LARGE_SYSTEM_PROMPT)]),
Message(role=MessageRole.user, content=[TextContent(text="What is 3+3?")]),
@@ -368,7 +371,7 @@ async def test_anthropic_prefix_caching_via_adapter():
)
# First request
adapter1 = SimpleLLMRequestAdapter(llm_client=client, llm_config=llm_config)
adapter1 = SimpleLLMRequestAdapter(llm_client=client, llm_config=llm_config, call_type=LLMCallType.agent_step)
messages1 = [
Message(role=MessageRole.system, content=[TextContent(text=LARGE_SYSTEM_PROMPT)]),
Message(role=MessageRole.user, content=[TextContent(text="What is 2+2?")]),
@@ -386,7 +389,7 @@ async def test_anthropic_prefix_caching_via_adapter():
)
# Second request
adapter2 = SimpleLLMRequestAdapter(llm_client=client, llm_config=llm_config)
adapter2 = SimpleLLMRequestAdapter(llm_client=client, llm_config=llm_config, call_type=LLMCallType.agent_step)
messages2 = [
Message(role=MessageRole.system, content=[TextContent(text=LARGE_SYSTEM_PROMPT)]),
Message(role=MessageRole.user, content=[TextContent(text="What is 3+3?")]),
@@ -435,7 +438,7 @@ async def test_gemini_prefix_caching_via_adapter():
)
# First request
adapter1 = SimpleLLMRequestAdapter(llm_client=client, llm_config=llm_config)
adapter1 = SimpleLLMRequestAdapter(llm_client=client, llm_config=llm_config, call_type=LLMCallType.agent_step)
messages1 = [
Message(role=MessageRole.system, content=[TextContent(text=LARGE_SYSTEM_PROMPT)]),
Message(role=MessageRole.user, content=[TextContent(text="What is 2+2?")]),
@@ -451,7 +454,7 @@ async def test_gemini_prefix_caching_via_adapter():
print(f"Request 1 - prompt={adapter1.usage.prompt_tokens}, cached={adapter1.usage.cached_input_tokens}")
# Second request
adapter2 = SimpleLLMRequestAdapter(llm_client=client, llm_config=llm_config)
adapter2 = SimpleLLMRequestAdapter(llm_client=client, llm_config=llm_config, call_type=LLMCallType.agent_step)
messages2 = [
Message(role=MessageRole.system, content=[TextContent(text=LARGE_SYSTEM_PROMPT)]),
Message(role=MessageRole.user, content=[TextContent(text="What is 3+3?")]),